torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/source/conf.py +6 -4
- docs/source/docstring template.py +46 -0
- tests/test_identical.py +2 -3
- tests/test_opts.py +64 -50
- tests/test_vars.py +1 -0
- torchzero/core/module.py +138 -6
- torchzero/core/transform.py +158 -51
- torchzero/modules/__init__.py +3 -2
- torchzero/modules/clipping/clipping.py +114 -17
- torchzero/modules/clipping/ema_clipping.py +27 -13
- torchzero/modules/clipping/growth_clipping.py +8 -7
- torchzero/modules/experimental/__init__.py +22 -5
- torchzero/modules/experimental/absoap.py +5 -2
- torchzero/modules/experimental/adadam.py +8 -2
- torchzero/modules/experimental/adamY.py +8 -2
- torchzero/modules/experimental/adam_lambertw.py +149 -0
- torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
- torchzero/modules/experimental/adasoap.py +7 -2
- torchzero/modules/experimental/cosine.py +214 -0
- torchzero/modules/experimental/cubic_adam.py +97 -0
- torchzero/modules/{projections → experimental}/dct.py +11 -11
- torchzero/modules/experimental/eigendescent.py +4 -1
- torchzero/modules/experimental/etf.py +32 -9
- torchzero/modules/experimental/exp_adam.py +113 -0
- torchzero/modules/experimental/expanded_lbfgs.py +141 -0
- torchzero/modules/{projections → experimental}/fft.py +10 -10
- torchzero/modules/experimental/hnewton.py +85 -0
- torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
- torchzero/modules/experimental/newtonnewton.py +7 -3
- torchzero/modules/experimental/parabolic_search.py +220 -0
- torchzero/modules/experimental/reduce_outward_lr.py +4 -4
- torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
- torchzero/modules/experimental/subspace_preconditioners.py +11 -4
- torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
- torchzero/modules/functional.py +12 -2
- torchzero/modules/grad_approximation/fdm.py +30 -3
- torchzero/modules/grad_approximation/forward_gradient.py +13 -3
- torchzero/modules/grad_approximation/grad_approximator.py +51 -6
- torchzero/modules/grad_approximation/rfdm.py +285 -38
- torchzero/modules/higher_order/higher_order_newton.py +152 -89
- torchzero/modules/line_search/__init__.py +4 -4
- torchzero/modules/line_search/adaptive.py +99 -0
- torchzero/modules/line_search/backtracking.py +34 -9
- torchzero/modules/line_search/line_search.py +70 -12
- torchzero/modules/line_search/polynomial.py +233 -0
- torchzero/modules/line_search/scipy.py +2 -2
- torchzero/modules/line_search/strong_wolfe.py +34 -7
- torchzero/modules/misc/__init__.py +27 -0
- torchzero/modules/{ops → misc}/debug.py +24 -1
- torchzero/modules/misc/escape.py +60 -0
- torchzero/modules/misc/gradient_accumulation.py +70 -0
- torchzero/modules/misc/misc.py +316 -0
- torchzero/modules/misc/multistep.py +158 -0
- torchzero/modules/misc/regularization.py +171 -0
- torchzero/modules/{ops → misc}/split.py +29 -1
- torchzero/modules/{ops → misc}/switch.py +44 -3
- torchzero/modules/momentum/__init__.py +1 -1
- torchzero/modules/momentum/averaging.py +6 -6
- torchzero/modules/momentum/cautious.py +45 -8
- torchzero/modules/momentum/ema.py +7 -7
- torchzero/modules/momentum/experimental.py +2 -2
- torchzero/modules/momentum/matrix_momentum.py +90 -63
- torchzero/modules/momentum/momentum.py +2 -1
- torchzero/modules/ops/__init__.py +3 -31
- torchzero/modules/ops/accumulate.py +6 -10
- torchzero/modules/ops/binary.py +72 -26
- torchzero/modules/ops/multi.py +77 -16
- torchzero/modules/ops/reduce.py +15 -7
- torchzero/modules/ops/unary.py +29 -13
- torchzero/modules/ops/utility.py +20 -12
- torchzero/modules/optimizers/__init__.py +12 -3
- torchzero/modules/optimizers/adagrad.py +23 -13
- torchzero/modules/optimizers/adahessian.py +223 -0
- torchzero/modules/optimizers/adam.py +7 -6
- torchzero/modules/optimizers/adan.py +110 -0
- torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
- torchzero/modules/optimizers/esgd.py +171 -0
- torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
- torchzero/modules/optimizers/lion.py +1 -1
- torchzero/modules/optimizers/mars.py +91 -0
- torchzero/modules/optimizers/msam.py +186 -0
- torchzero/modules/optimizers/muon.py +30 -5
- torchzero/modules/optimizers/orthograd.py +1 -1
- torchzero/modules/optimizers/rmsprop.py +7 -4
- torchzero/modules/optimizers/rprop.py +42 -8
- torchzero/modules/optimizers/sam.py +163 -0
- torchzero/modules/optimizers/shampoo.py +39 -5
- torchzero/modules/optimizers/soap.py +29 -19
- torchzero/modules/optimizers/sophia_h.py +71 -14
- torchzero/modules/projections/__init__.py +2 -4
- torchzero/modules/projections/cast.py +51 -0
- torchzero/modules/projections/galore.py +3 -1
- torchzero/modules/projections/projection.py +188 -94
- torchzero/modules/quasi_newton/__init__.py +12 -2
- torchzero/modules/quasi_newton/cg.py +160 -59
- torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
- torchzero/modules/quasi_newton/lbfgs.py +154 -97
- torchzero/modules/quasi_newton/lsr1.py +101 -57
- torchzero/modules/quasi_newton/quasi_newton.py +863 -215
- torchzero/modules/quasi_newton/trust_region.py +397 -0
- torchzero/modules/second_order/__init__.py +2 -2
- torchzero/modules/second_order/newton.py +220 -41
- torchzero/modules/second_order/newton_cg.py +300 -11
- torchzero/modules/second_order/nystrom.py +104 -1
- torchzero/modules/smoothing/gaussian.py +34 -0
- torchzero/modules/smoothing/laplacian.py +14 -4
- torchzero/modules/step_size/__init__.py +2 -0
- torchzero/modules/step_size/adaptive.py +122 -0
- torchzero/modules/step_size/lr.py +154 -0
- torchzero/modules/weight_decay/__init__.py +1 -1
- torchzero/modules/weight_decay/weight_decay.py +89 -7
- torchzero/modules/wrappers/optim_wrapper.py +29 -1
- torchzero/optim/wrappers/directsearch.py +39 -2
- torchzero/optim/wrappers/fcmaes.py +21 -13
- torchzero/optim/wrappers/mads.py +5 -6
- torchzero/optim/wrappers/nevergrad.py +16 -1
- torchzero/optim/wrappers/optuna.py +1 -1
- torchzero/optim/wrappers/scipy.py +5 -3
- torchzero/utils/__init__.py +2 -2
- torchzero/utils/derivatives.py +3 -3
- torchzero/utils/linalg/__init__.py +1 -1
- torchzero/utils/linalg/solve.py +251 -12
- torchzero/utils/numberlist.py +2 -0
- torchzero/utils/python_tools.py +10 -0
- torchzero/utils/tensorlist.py +40 -28
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
- torchzero-0.3.11.dist-info/RECORD +159 -0
- torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
- torchzero/modules/experimental/soapy.py +0 -163
- torchzero/modules/experimental/structured_newton.py +0 -111
- torchzero/modules/lr/__init__.py +0 -2
- torchzero/modules/lr/adaptive.py +0 -93
- torchzero/modules/lr/lr.py +0 -63
- torchzero/modules/ops/misc.py +0 -418
- torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
- torchzero/modules/quasi_newton/olbfgs.py +0 -196
- torchzero-0.3.10.dist-info/RECORD +0 -139
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
- {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/misc/misc.py (new file)
@@ -0,0 +1,316 @@
+from collections import deque
+from collections.abc import Iterable
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+class Previous(TensorwiseTransform):
+    """Maintains an update from n steps back, for example if n=1, returns previous update"""
+    def __init__(self, n=1, target: Target = 'update'):
+        defaults = dict(n=n)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        n = setting['n']
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=n+1)
+
+        state['history'].append(tensor)
+
+        return state['history'][0]
+
+
+class LastDifference(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_tensors = unpack_states(states, tensors, 'prev_tensors') # initialized to 0
+        difference = torch._foreach_sub(tensors, prev_tensors)
+        for p, c in zip(prev_tensors, tensors): p.set_(c)
+        return difference
+
+class LastGradDifference(Module):
+    """Outputs difference between past two gradients."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        grad = var.get_grad()
+        prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
+        difference = torch._foreach_sub(grad, prev_grad)
+        for p, c in zip(prev_grad, grad): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+class LastParamDifference(Module):
+    """Outputs difference between past two parameters, which is the effective previous update."""
+    def __init__(self):
+        super().__init__({})
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        prev_params = self.get_state(var.params, 'prev_params') # initialized to 0
+        difference = torch._foreach_sub(params, prev_params)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        var.update = list(difference)
+        return var
+
+
+
+class LastProduct(Transform):
+    """Outputs difference between past two updates."""
+    def __init__(self,target: Target = 'update'):
+        super().__init__({}, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
+        prod = torch._foreach_mul(tensors, prev)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return prod
+
+class LastRatio(Transform):
+    """Outputs ratio between past two updates, the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
+        defaults = dict(numerator=numerator)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class LastAbsoluteRatio(Transform):
+    """Outputs ratio between absolute values of past two updates the numerator is determined by :code:`numerator` argument."""
+    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
+        defaults = dict(numerator=numerator, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
+        numerator = settings[0]['numerator']
+        eps = NumberList(s['eps'] for s in settings)
+
+        torch._foreach_abs_(tensors)
+        torch._foreach_clamp_min_(prev, eps)
+
+        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
+        else: ratio = torch._foreach_div(prev, tensors)
+        for p, c in zip(prev, tensors): p.set_(c)
+        return ratio
+
+class GradSign(Transform):
+    """Copies gradient sign to update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [t.copysign_(g) for t,g in zip(tensors, grads)]
+
+class UpdateSign(Transform):
+    """Outputs gradient with sign copied from the update."""
+    def __init__(self, target: Target = 'update'):
+        super().__init__({}, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
+
+class GraftToGrad(Transform):
+    """Grafts update to the gradient, that is update is rescaled to have the same norm as the gradient."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class GraftGradToUpdate(Transform):
+    """Outputs gradient grafted to update, that is gradient rescaled to have the same norm as the update."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
+
+
+class GraftToParams(Transform):
+    """Grafts update to the parameters, that is update is rescaled to have the same norm as the parameters, but no smaller than :code:`eps`."""
+    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
+        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
+        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
+
+class Relative(Transform):
+    """Multiplies update by absolute parameter values to make it relative to their magnitude, :code:`min_value` is minimum allowed value to avoid getting stuck at 0."""
+    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
+        defaults = dict(min_value=min_value)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
+        torch._foreach_mul_(tensors, mul)
+        return tensors
+
+class FillLoss(Module):
+    """Outputs tensors filled with loss value times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, backward: bool = True):
+        defaults = dict(alpha=alpha, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha = self.get_settings(var.params, 'alpha')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
+        return var
+
+class MulByLoss(Module):
+    """Multiplies update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_mul_(var.update, mul)
+        return var
+
+class DivByLoss(Module):
+    """Divides update by loss times :code:`alpha`"""
+    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
+        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
+        loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
+        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
+        torch._foreach_div_(var.update, mul)
+        return var
+
+
+class NoiseSign(Transform):
+    """Outputs random tensors with sign copied from the update."""
+    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
+        defaults = dict(distribution=distribution, alpha=alpha)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        alpha = [s['alpha'] for s in settings]
+        distribution = self.settings[params[0]]['distribution']
+        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
+
+class HpuEstimate(Transform):
+    """returns ``y/||s||``, where ``y`` is difference between current and previous update (gradient), ``s`` is difference between current and previous parameters. The returned tensors are a finite difference approximation to hessian times previous update."""
+    def __init__(self):
+        defaults = dict()
+        super().__init__(defaults, uses_grad=False)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_update')
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_params, prev_update = self.get_state(params, 'prev_params', 'prev_update') # initialized to 0
+        s = torch._foreach_sub(params, prev_params)
+        y = torch._foreach_sub(tensors, prev_update)
+        for p, c in zip(prev_params, params): p.copy_(c)
+        for p, c in zip(prev_update, tensors): p.copy_(c)
+        torch._foreach_div_(y, torch.linalg.norm(torch.cat([t.ravel() for t in s])).clip(min=1e-8)) # pylint:disable=not-callable
+        self.store(params, ['s', 'y'], [s, y])
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return [self.state[p]['y'] for p in params]
+
+class RandomHvp(Module):
+    """Returns a hessian-vector product with a random vector"""
+
+    def __init__(
+        self,
+        n_samples: int = 1,
+        distribution: Distributions = "normal",
+        update_freq: int = 1,
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+        h=1e-3,
+    ):
+        defaults = dict(n_samples=n_samples, distribution=distribution, hvp_method=hvp_method, h=h, update_freq=update_freq)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        params = TensorList(var.params)
+        settings = self.settings[params[0]]
+        n_samples = settings['n_samples']
+        distribution = settings['distribution']
+        hvp_method = settings['hvp_method']
+        h = settings['h']
+        update_freq = settings['update_freq']
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad = None
+            for i in range(n_samples):
+                u = params.sample_like(distribution=distribution)
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+            if update_freq != 1:
+                assert D is not None
+                D_buf = self.get_state(params, "D", cls=TensorList)
+                D_buf.set_(D)
+
+        if D is None:
+            D = self.get_state(params, "D", cls=TensorList)
+
+        var.update = list(D)
+        return var
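
The transforms above follow the same chainable pattern as the rest of torchzero, so a usage sketch may help. This is an illustration only: it assumes these classes are re-exported under tz.m, as the docstrings elsewhere in this diff (Dropout, Split) do for other modules, and that tz.Modular is driven with a backward-flag closure, matching the closure convention visible in the code above.

import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 2)

# hypothetical chain: take Adam's direction but rescale it to the raw
# gradient's norm with GraftToGrad, then apply a learning rate
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.GraftToGrad(),   # assumed to be exported as tz.m.GraftToGrad
    tz.m.LR(1e-3),
)

def closure(backward=True):
    loss = model(torch.randn(16, 10)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
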
torchzero/modules/misc/multistep.py (new file)
@@ -0,0 +1,158 @@
+from collections.abc import Iterable
+
+import torch
+
+from ...core import Chainable, Module, Var
+from ...utils import TensorList
+
+def _sequential_step(self: Module, var: Var, sequential: bool):
+    params = var.params
+    steps = self.settings[params[0]]['steps']
+
+    if sequential: modules = self.get_children_sequence() * steps
+    else: modules = [self.children['module']] * steps
+
+    if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
+
+    # store original params unless this is last module and can update params directly
+    params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
+
+    # first step - pass var as usual
+    var = modules[0].step(var)
+    new_var = var
+
+    # subsequent steps - update parameters and create new var
+    if len(modules) > 1:
+        for m in modules[1:]:
+
+            # update params
+            if (not new_var.skip_update):
+                if new_var.last_module_lrs is not None:
+                    torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+                torch._foreach_sub_(params, new_var.get_update())
+
+            # create new var since we are at a new point, that means grad, update and loss will be None
+            new_var = Var(params=new_var.params, closure=new_var.closure,
+                          model=new_var.model, current_step=new_var.current_step + 1)
+
+            # step
+            new_var = m.step(new_var)
+
+    # final parameter update
+    if (not new_var.skip_update):
+        if new_var.last_module_lrs is not None:
+            torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
+
+        torch._foreach_sub_(params, new_var.get_update())
+
+    # if last module, update is applied so return new var
+    if params_before_steps is None:
+        new_var.stop = True
+        new_var.skip_update = True
+        return new_var
+
+    # otherwise use parameter difference as update
+    var.update = list(torch._foreach_sub(params_before_steps, params))
+    for p, bef in zip(params, params_before_steps):
+        p.set_(bef) # pyright:ignore[reportArgumentType]
+    return var
+
+class Multistep(Module):
+    """Performs :code:`steps` inner steps with :code:`module` per each step.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, module: Chainable, steps: int):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=False)
+
+class Sequential(Module):
+    """On each step, this sequentially steps with :code:`modules` :code:`steps` times.
+
+    The update is taken to be the parameter difference between parameters before and after the inner loop."""
+    def __init__(self, modules: Iterable[Chainable], steps: int=1):
+        defaults = dict(steps=steps)
+        super().__init__(defaults)
+        self.set_children_sequence(modules)
+
+    @torch.no_grad
+    def step(self, var):
+        return _sequential_step(self, var, sequential=True)
+
+
+class NegateOnLossIncrease(Module):
+    """Uses an extra forward pass to evaluate loss at :code:`parameters+update`,
+    if loss is larger than at :code:`parameters`,
+    the update is set to 0 if :code:`backtrack=False` and to :code:`-update` otherwise"""
+    def __init__(self, backtrack=False):
+        defaults = dict(backtrack=backtrack)
+        super().__init__(defaults=defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
+        backtrack = self.settings[var.params[0]]['backtrack']
+
+        update = var.get_update()
+        f_0 = var.get_loss(backward=False)
+
+        torch._foreach_sub_(var.params, update)
+        f_1 = closure(False)
+
+        if f_1 <= f_0:
+            if var.is_last and var.last_module_lrs is None:
+                var.stop = True
+                var.skip_update = True
+                return var
+
+            torch._foreach_add_(var.params, update)
+            return var
+
+        torch._foreach_add_(var.params, update)
+        if backtrack:
+            torch._foreach_neg_(var.update)
+        else:
+            torch._foreach_zero_(var.update)
+        return var
+
+
+class Online(Module):
+    """Allows certain modules to be used for mini-batch optimization."""
+    def __init__(self, module: Chainable,):
+        super().__init__()
+
+        self.set_child('module', module)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise ValueError("Closure must be passed for Online")
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
+        params = TensorList(var.params)
+        p_cur = params.clone()
+        p_prev = self.get_state(params, 'p_prev', cls=TensorList)
+        module = self.children['module']
+
+        if step == 1:
+            var = module.step(var.clone(clone_update=False))
+
+            p_prev.copy_(params)
+            return var
+
+        # restore previous params
+        var_prev = Var(params=params, closure=closure, model=var.model, current_step=var.current_step)
+        params.set_(p_prev)
+        module.reset_for_online()
+        module.update(var_prev)
+
+        # restore current params
+        params.set_(p_cur)
+        p_prev.copy_(params)
+        return module.step(var.clone(clone_update=False))
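
As a rough usage sketch (not taken from the package itself): Multistep wraps a chain, replays it several times per outer step, and surfaces the accumulated parameter difference as a single update. The names below assume the module is re-exported as tz.m.Multistep and that a backward-flag closure is used, matching the conventions visible in the code above; _sequential_step raises if no closure is given.

import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 2)
x, y = torch.randn(64, 10), torch.randn(64, 2)

# hypothetical: five momentum+LR inner steps per optimizer step,
# exposed to the rest of the chain as one combined update
opt = tz.Modular(
    model.parameters(),
    tz.m.Multistep([tz.m.HeavyBall(), tz.m.LR(1e-2)], steps=5),
    tz.m.LR(1.0),
)

def closure(backward=True):
    loss = (model(x) - y).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(10):
    opt.step(closure)
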
torchzero/modules/misc/regularization.py (new file)
@@ -0,0 +1,171 @@
+from collections import deque
+from collections.abc import Iterable
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
+from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
+
+
+class Dropout(Transform):
+    """Applies dropout to the update.
+
+    For each weight the update to that weight has :code:`p` probability to be set to 0.
+    This can be used to implement gradient dropout or update dropout depending on placement.
+
+    Args:
+        p (float, optional): probability that update for a weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, update after dropout is rescaled to have the same norm as before dropout. Defaults to False.
+        target (Target, optional): what to set on var, refer to documentation. Defaults to 'update'.
+
+
+    Examples:
+        Gradient dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Dropout(0.5),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Update dropout.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.Dropout(0.5),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
+        defaults = dict(p=p, graft=graft)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        p = NumberList(s['p'] for s in settings)
+        graft = settings[0]['graft']
+
+        if graft:
+            target_norm = tensors.global_vector_norm()
+            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
+
+        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
+
+def _bernoulli_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be 0."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator)
+
+class WeightDropout(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients with random weights replaced with 0.
+
+    Dropout can be disabled for a parameter by setting :code:`use_dropout=False` in corresponding parameter group.
+
+    Args:
+        p (float, optional): probability that any weight is replaced with 0. Defaults to 0.5.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+    """
+    def __init__(self, p: float = 0.5, graft: bool = True):
+        defaults = dict(p=p, graft=graft, use_dropout=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+        p = NumberList(self.settings[p]['p'] for p in params)
+
+        # create masks
+        mask = []
+        for p, m in zip(params, mask):
+            prob = self.settings[p]['p']
+            use_dropout = self.settings[p]['use_dropout']
+            if use_dropout: mask.append(_bernoulli_like(p, prob))
+            else: mask.append(torch.ones_like(p))
+
+        @torch.no_grad
+        def dropout_closure(backward=True):
+            orig_params = params.clone()
+            params.mul_(mask)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.copy_(orig_params)
+            return loss
+
+        var.closure = dropout_closure
+        return var
+
+
+class PerturbWeights(Module):
+    """
+    Changes the closure so that it evaluates loss and gradients at weights perturbed by a random perturbation.
+
+    Can be disabled for a parameter by setting :code:`perturb=False` in corresponding parameter group.
+
+    Args:
+        alpha (float, optional): multiplier for perturbation magnitude. Defaults to 0.1.
+        relative (bool, optional): whether to multiply perturbation by mean absolute value of the parameter. Defaults to True.
+        graft (bool, optional):
+            if True, parameters after dropout are rescaled to have the same norm as before dropout. Defaults to False.
+    """
+    def __init__(self, alpha: float = 0.1, relative:bool=True, distribution:Distributions = 'normal'):
+        defaults = dict(alpha=alpha, relative=relative, distribution=distribution, perturb=True)
+        super().__init__(defaults)
+
+    @torch.no_grad
+    def step(self, var):
+        closure = var.closure
+        if closure is None: raise RuntimeError('WeightDropout requires closure')
+        params = TensorList(var.params)
+
+        # create perturbations
+        perts = []
+        for p in params:
+            settings = self.settings[p]
+            if not settings['perturb']:
+                perts.append(torch.zeros_like(p))
+                continue
+
+            alpha = settings['alpha']
+            if settings['relative']:
+                alpha *= p.abs().mean()
+
+            distribution = self.settings[p]['distribution'].lower()
+            if distribution in ('normal', 'gaussian'):
+                perts.append(torch.randn_like(p).mul_(alpha))
+            elif distribution == 'uniform':
+                perts.append(torch.empty_like(p).uniform_(-alpha,alpha))
+            elif distribution == 'sphere':
+                r = torch.randn_like(p)
+                perts.append((r * alpha) / torch.linalg.vector_norm(r)) # pylint:disable=not-callable
+            else:
+                raise ValueError(distribution)
+
+        @torch.no_grad
+        def perturbed_closure(backward=True):
+            params.add_(perts)
+            if backward:
+                with torch.enable_grad(): loss = closure()
+            else:
+                loss = closure(False)
+            params.sub_(perts)
+            return loss
+
+        var.closure = perturbed_closure
+        return var
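
WeightDropout and PerturbWeights wrap the closure rather than the update, so they are placed before the modules whose gradients they should affect. A sketch under the same assumptions as above (tz.m re-exports, backward-flag closure); the specific module placement and values are illustrative only:

import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 2)

# hypothetical: gradients are evaluated at randomly perturbed weights
# (PerturbWeights replaces var.closure), then fed to Adam
opt = tz.Modular(
    model.parameters(),
    tz.m.PerturbWeights(alpha=0.05, relative=True),  # assumed tz.m export
    tz.m.Adam(),
    tz.m.LR(1e-3),
)

def closure(backward=True):
    loss = model(torch.randn(16, 10)).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)
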
torchzero/modules/{ops → misc}/split.py
@@ -45,7 +45,35 @@ def _split(
     return var
 
 class Split(Module):
-    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
+    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters.
+
+    Args:
+        filter (Callable[[torch.Tensor], bool]): a function that takes in a parameter tensor and returns a boolean value.
+        true (Chainable | None): modules that are applied to tensors where :code:`filter` returned True.
+        false (Chainable | None): modules that are applied to tensors where :code:`filter` returned False.
+
+    Examples:
+        standard Muon with Adam fallback
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.head.parameters(),
+                tz.m.Split(
+                    # apply muon only to 2D+ parameters
+                    filter = lambda t: t.ndim >= 2,
+                    true = [
+                        tz.m.HeavyBall(),
+                        tz.m.Orthogonalize(),
+                        tz.m.LR(1e-2),
+                    ],
+                    false = tz.m.Adam()
+                ),
+                tz.m.LR(1e-2)
+            )
+
+
+    """
     def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
         defaults = dict(filter=filter)
         super().__init__(defaults)