torchzero-0.3.10-py3-none-any.whl → torchzero-0.3.11-py3-none-any.whl

This diff shows the changes between the publicly released 0.3.10 and 0.3.11 versions of the torchzero package, as they appear in the public registry, and is provided for informational purposes only.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
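Note on the layout change: the deleted torchzero/modules/ops/misc.py and torchzero/modules/quasi_newton/olbfgs.py (shown in full below), together with the removed torchzero/modules/lr/* files, appear to have counterparts in the new torchzero/modules/misc/ and torchzero/modules/step_size/ packages listed above. The sketch below shows how downstream imports might be migrated; the 0.3.11 paths are inferred purely from the file names in this listing and should be treated as assumptions, not a documented API.

# Hypothetical import migration for torchzero 0.3.10 -> 0.3.11.
# The 0.3.11 paths are guesses based on the renamed/added files above;
# verify against the installed package (or its public re-exports) before relying on them.
try:
    from torchzero.modules.misc.gradient_accumulation import GradientAccumulation  # assumed new location
    from torchzero.modules.misc.multistep import Multistep                          # assumed new location
except ImportError:
    from torchzero.modules.ops.misc import GradientAccumulation, Multistep          # 0.3.10 layout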
torchzero/modules/ops/misc.py (deleted)
@@ -1,418 +0,0 @@
- from collections import deque
- from collections.abc import Iterable
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
- from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
-
-
- class Previous(TensorwiseTransform):
-     """Maintains an update from n steps back, for example if n=1, returns previous update"""
-     def __init__(self, n=1, target: Target = 'update'):
-         defaults = dict(n=n)
-         super().__init__(uses_grad=False, defaults=defaults, target=target)
-
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         n = settings['n']
-
-         if 'history' not in state:
-             state['history'] = deque(maxlen=n+1)
-
-         state['history'].append(tensor)
-
-         return state['history'][0]
-
-
- class LastDifference(Transform):
-     """Difference between past two updates."""
-     def __init__(self,target: Target = 'update'):
-         super().__init__({}, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
-         difference = torch._foreach_sub(tensors, prev)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return difference
-
- class LastGradDifference(Module):
-     """Difference between past two grads."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def step(self, var):
-         grad = var.get_grad()
-         prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
-         difference = torch._foreach_sub(grad, prev_grad)
-         for p, c in zip(prev_grad, grad): p.set_(c)
-         var.update = list(difference)
-         return var
-
-
- class LastProduct(Transform):
-     """Difference between past two updates."""
-     def __init__(self,target: Target = 'update'):
-         super().__init__({}, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
-         prod = torch._foreach_mul(tensors, prev)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return prod
-
- class LastRatio(Transform):
-     """Ratio between past two updates."""
-     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
-         defaults = dict(numerator=numerator)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-         numerator = settings[0]['numerator']
-         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-         else: ratio = torch._foreach_div(prev, tensors)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return ratio
-
- class LastAbsoluteRatio(Transform):
-     """Ratio between absolute values of past two updates."""
-     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
-         defaults = dict(numerator=numerator, eps=eps)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-         numerator = settings[0]['numerator']
-         eps = NumberList(s['eps'] for s in settings)
-
-         torch._foreach_abs_(tensors)
-         torch._foreach_clamp_min_(prev, eps)
-
-         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-         else: ratio = torch._foreach_div(prev, tensors)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return ratio
-
- class GradSign(Transform):
-     """copy gradient sign to update."""
-     def __init__(self, target: Target = 'update'):
-         super().__init__({}, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         return [t.copysign_(g) for t,g in zip(tensors, grads)]
-
- class UpdateSign(Transform):
-     """use per-weight magnitudes from grad while using sign from update."""
-     def __init__(self, target: Target = 'update'):
-         super().__init__({}, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
-
- class GraftToGrad(Transform):
-     """use gradient norm and update direction."""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
-
- class GraftGradToUpdate(Transform):
-     """use update norm and gradient direction."""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
-
-
- class GraftToParams(Transform):
-     """makes update norm be set to parameter norm, but norm won't go below eps"""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
-
- class Relative(Transform):
-     """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
-     def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
-         defaults = dict(min_value=min_value)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
-         torch._foreach_mul_(tensors, mul)
-         return tensors
-
- class FillLoss(Module):
-     """makes tensors filled with loss value times alpha"""
-     def __init__(self, alpha: float = 1, backward: bool = True):
-         defaults = dict(alpha=alpha, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha = self.get_settings(var.params, 'alpha')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
-         return var
-
- class MulByLoss(Module):
-     """multiplies update by loss times alpha"""
-     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-         torch._foreach_mul_(var.update, mul)
-         return var
-
- class DivByLoss(Module):
-     """divides update by loss times alpha"""
-     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-         torch._foreach_div_(var.update, mul)
-         return var
-
-
-
- def _sequential_step(self: Module, var: Var, sequential: bool):
-     params = var.params
-     steps = self.settings[params[0]]['steps']
-
-     if sequential: modules = self.get_children_sequence()
-     else: modules = [self.children['module']] * steps
-
-     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
-
-     # store original params unless this is last module and can update params directly
-     params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
-
-     # first step - pass var as usual
-     var = modules[0].step(var)
-     new_var = var
-
-     # subsequent steps - update parameters and create new var
-     if len(modules) > 1:
-         for m in modules[1:]:
-
-             # update params
-             if (not new_var.skip_update):
-                 if new_var.last_module_lrs is not None:
-                     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-                 torch._foreach_sub_(params, new_var.get_update())
-
-             # create new var since we are at a new point, that means grad, update and loss will be None
-             new_var = Var(params=new_var.params, closure=new_var.closure,
-                           model=new_var.model, current_step=new_var.current_step + 1)
-
-             # step
-             new_var = m.step(new_var)
-
-     # final parameter update
-     if (not new_var.skip_update):
-         if new_var.last_module_lrs is not None:
-             torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-         torch._foreach_sub_(params, new_var.get_update())
-
-     # if last module, update is applied so return new var
-     if params_before_steps is None:
-         new_var.stop = True
-         new_var.skip_update = True
-         return new_var
-
-     # otherwise use parameter difference as update
-     var.update = list(torch._foreach_sub(params_before_steps, params))
-     for p, bef in zip(params, params_before_steps):
-         p.set_(bef) # pyright:ignore[reportArgumentType]
-     return var
-
- class Multistep(Module):
-     def __init__(self, module: Chainable, steps: int):
-         defaults = dict(steps=steps)
-         super().__init__(defaults)
-         self.set_child('module', module)
-
-     @torch.no_grad
-     def step(self, var):
-         return _sequential_step(self, var, sequential=False)
-
- class Sequential(Module):
-     def __init__(self, modules: Iterable[Chainable], steps: int):
-         defaults = dict(steps=steps)
-         super().__init__(defaults)
-         self.set_children_sequence(modules)
-
-     @torch.no_grad
-     def step(self, var):
-         return _sequential_step(self, var, sequential=True)
-
-
- class GradientAccumulation(Module):
-     """gradient accumulation"""
-     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-         defaults = dict(n=n, mean=mean, stop=stop)
-         super().__init__(defaults)
-         self.set_child('modules', modules)
-
-
-     @torch.no_grad
-     def step(self, var):
-         accumulator = self.get_state(var.params, 'accumulator')
-         settings = self.settings[var.params[0]]
-         n = settings['n']; mean = settings['mean']; stop = settings['stop']
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         # add update to accumulator
-         torch._foreach_add_(accumulator, var.get_update())
-
-         # step with accumulated updates
-         if step % n == 0:
-             if mean:
-                 torch._foreach_div_(accumulator, n)
-
-             var.update = [a.clone() for a in accumulator]
-             var = self.children['modules'].step(var)
-
-             # zero accumulator
-             torch._foreach_zero_(accumulator)
-
-         else:
-             # prevent update
-             if stop:
-                 var.stop=True
-                 var.skip_update=True
-
-         return var
-
-
- class Dropout(Transform):
-     def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
-         defaults = dict(p=p, graft=graft)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         tensors = TensorList(tensors)
-         p = NumberList(s['p'] for s in settings)
-         graft = settings[0]['graft']
-
-         if graft:
-             target_norm = tensors.global_vector_norm()
-             tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-             return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
-
-         return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-
- class WeightDropout(Module):
-     """Applies dropout directly to weights."""
-     def __init__(self, p: float = 0.5, graft: bool = True):
-         defaults = dict(p=p, graft=graft)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         if closure is None: raise RuntimeError('WeightDropout requires closure')
-         params = TensorList(var.params)
-         p = NumberList(self.settings[p]['p'] for p in params)
-         mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
-
-         @torch.no_grad
-         def dropout_closure(backward=True):
-             orig_params = params.clone()
-             params.mul_(mask)
-             if backward:
-                 with torch.enable_grad(): loss = closure()
-             else:
-                 loss = closure(False)
-             params.copy_(orig_params)
-             return loss
-
-         var.closure = dropout_closure
-         return var
-
- class NoiseSign(Transform):
-     """uses random vector with update sign"""
-     def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-         defaults = dict(distribution=distribution, alpha=alpha)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         alpha = [s['alpha'] for s in settings]
-         distribution = self.settings[params[0]]['distribution']
-         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
-
-
- class NegateOnLossIncrease(Module):
-     def __init__(self, backtrack=True):
-         defaults = dict(backtrack=backtrack)
-         super().__init__(defaults=defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-         backtrack = self.settings[var.params[0]]['backtrack']
-
-         update = var.get_update()
-         f_0 = var.get_loss(backward=False)
-
-         torch._foreach_sub_(var.params, update)
-         f_1 = closure(False)
-
-         if f_1 <= f_0:
-             if var.is_last and var.last_module_lrs is None:
-                 var.stop = True
-                 var.skip_update = True
-                 return var
-
-             torch._foreach_add_(var.params, update)
-             return var
-
-         torch._foreach_add_(var.params, update)
-         if backtrack:
-             torch._foreach_neg_(var.update)
-         else:
-             torch._foreach_zero_(var.update)
-         return var
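The GraftToGrad, GraftGradToUpdate and GraftToParams transforms above delegate to TensorList.graft_, i.e. they keep one tensor's direction but rescale it to another tensor's norm. A minimal stand-alone sketch of that operation on a single flat tensor (the helper name and interface here are illustrative, not torchzero's TensorList API):

import torch

def graft(direction: torch.Tensor, magnitude: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # keep `direction`'s direction, but rescale it to `magnitude`'s global L2 norm
    d_norm = torch.linalg.vector_norm(direction).clamp(min=eps)
    m_norm = torch.linalg.vector_norm(magnitude)
    return direction * (m_norm / d_norm)

update, grad = torch.randn(10), torch.randn(10)
grafted = graft(update, grad)  # direction of `update`, norm of `grad` (what GraftToGrad produces)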
torchzero/modules/quasi_newton/experimental/__init__.py (deleted)
@@ -1 +0,0 @@
- from .modular_lbfgs import ModularLBFGS
torchzero/modules/quasi_newton/olbfgs.py (deleted)
@@ -1,196 +0,0 @@
- from collections import deque
- from functools import partial
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, Transform, Var, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist
- from .lbfgs import _adaptive_damping, lbfgs
-
-
- @torch.no_grad
- def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
-     assert var.closure is not None
-     with torch.enable_grad(): var.closure()
-     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
-     s_k = var.params - prev_params
-     y_k = grad - prev_grad
-     ys_k = s_k.dot(y_k)
-
-     if damping:
-         s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-     if ys_k > 1e-10:
-         s_history.append(s_k)
-         y_history.append(y_k)
-         sy_history.append(ys_k)
-
-
-
- class OnlineLBFGS(Module):
-     """Online L-BFGS.
-     Parameter and gradient differences are sampled from the same mini-batch by performing an extra forward and backward pass.
-     However I did a bunch of experiments and the online part doesn't seem to help. Normal L-BFGS is usually still
-     better because it performs twice as many steps, and it is reasonably stable with normalization or grafting.
-
-     Args:
-         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-         sample_grads (str, optional):
-             - "before" - samples current mini-batch gradient at previous and current parameters, calculates y_k
-               and adds it to history before stepping.
-             - "after" - samples current mini-batch gradient at parameters before stepping and after updating parameters.
-               s_k and y_k are added after parameter update, therefore they are delayed by 1 step.
-
-             In practice both modes behave very similarly. Defaults to 'before'.
-         tol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-         damping (bool, optional):
-             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-         init_damping (float, optional):
-             initial damping for adaptive dampening. Defaults to 0.9.
-         eigval_bounds (tuple, optional):
-             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-         params_beta (float | None, optional):
-             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
-         grads_beta (float | None, optional):
-             if not None, EMA of gradients is used for preconditioner update. Defaults to None.
-         update_freq (int, optional):
-             how often to update L-BFGS history. Defaults to 1.
-         z_beta (float | None, optional):
-             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-         inner (Chainable | None, optional):
-             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
-     """
-     def __init__(
-         self,
-         history_size=10,
-         sample_grads: Literal['before', 'after'] = 'before',
-         tol: float | None = 1e-10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         z_beta: float | None = None,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta)
-         super().__init__(defaults)
-
-         self.global_state['s_history'] = deque(maxlen=history_size)
-         self.global_state['y_history'] = deque(maxlen=history_size)
-         self.global_state['sy_history'] = deque(maxlen=history_size)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     def reset(self):
-         """Resets the internal state of the L-SR1 module."""
-         # super().reset() # Clears self.state (per-parameter) if any, and "step"
-         # Re-initialize L-SR1 specific global state
-         self.state.clear()
-         self.global_state['step'] = 0
-         self.global_state['s_history'].clear()
-         self.global_state['y_history'].clear()
-         self.global_state['sy_history'].clear()
-
-     @torch.no_grad
-     def step(self, var):
-         assert var.closure is not None
-
-         params = as_tensorlist(var.params)
-         update = as_tensorlist(var.get_update())
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         # history of s and k
-         s_history: deque[TensorList] = self.global_state['s_history']
-         y_history: deque[TensorList] = self.global_state['y_history']
-         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-         tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
-             'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])
-
-         # sample gradient at previous params with current mini-batch
-         if sample_grads == 'before':
-             prev_params = self.get_state(params, 'prev_params', cls=TensorList)
-             if step == 0:
-                 s_k = None; y_k = None; ys_k = None
-             else:
-                 s_k = params - prev_params
-
-                 current_params = params.clone()
-                 params.set_(prev_params)
-                 with torch.enable_grad(): var.closure()
-                 y_k = update - params.grad
-                 ys_k = s_k.dot(y_k)
-                 params.set_(current_params)
-
-                 if damping:
-                     s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-                 if ys_k > 1e-10:
-                     s_history.append(s_k)
-                     y_history.append(y_k)
-                     sy_history.append(ys_k)
-
-             prev_params.copy_(params)
-
-         # use previous s_k, y_k pair, samples gradient at current batch before and after updating parameters
-         elif sample_grads == 'after':
-             if len(s_history) == 0:
-                 s_k = None; y_k = None; ys_k = None
-             else:
-                 s_k = s_history[-1]
-                 y_k = y_history[-1]
-                 ys_k = s_k.dot(y_k)
-
-             # this will run after params are updated by Modular after running all future modules
-             var.post_step_hooks.append(
-                 partial(
-                     _store_sk_yk_after_step_hook,
-                     prev_params=params.clone(),
-                     prev_grad=update.clone(),
-                     damping=damping,
-                     init_damping=init_damping,
-                     eigval_bounds=eigval_bounds,
-                     s_history=s_history,
-                     y_history=y_history,
-                     sy_history=sy_history,
-                 ))
-
-         else:
-             raise ValueError(sample_grads)
-
-         # step with inner module before applying preconditioner
-         if self.children:
-             update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-         # tolerance on gradient difference to avoid exploding after converging
-         if tol is not None:
-             if y_k is not None and y_k.abs().global_max() <= tol:
-                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 return var
-
-         # lerp initial H^-1 @ q guess
-         z_ema = None
-         if z_beta is not None:
-             z_ema = self.get_state(params, 'z_ema', cls=TensorList)
-
-         # precondition
-         dir = lbfgs(
-             tensors_=as_tensorlist(update),
-             s_history=s_history,
-             y_history=y_history,
-             sy_history=sy_history,
-             y_k=y_k,
-             ys_k=ys_k,
-             z_beta = z_beta,
-             z_ema = z_ema,
-             step=step
-         )
-
-         var.update = dir
-
-         return var
-
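For context on the lbfgs() helper that the deleted OnlineLBFGS.step calls at the end: given the stored s, y and sy histories it presumably evaluates the standard L-BFGS two-loop recursion to approximate H^-1 @ update. A generic sketch on flat tensors (illustrative only; torchzero's actual helper operates on TensorLists and takes extra arguments such as z_beta/z_ema):

import torch

def lbfgs_two_loop(grad: torch.Tensor, s_history: list, y_history: list, sy_history: list) -> torch.Tensor:
    # classic two-loop recursion: returns an approximation of H^-1 @ grad
    q = grad.clone()
    alphas = []
    # first loop: newest pair to oldest
    for s, y, sy in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
        a = s.dot(q) / sy
        alphas.append(a)
        q = q - a * y
    # initial inverse-Hessian scaling gamma = (s.y) / (y.y) from the newest pair
    if y_history:
        q = q * (sy_history[-1] / y_history[-1].dot(y_history[-1]))
    # second loop: oldest pair to newest
    for s, y, sy, a in zip(s_history, y_history, sy_history, reversed(alphas)):
        b = y.dot(q) / sy
        q = q + s * (a - b)
    return q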