torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/ops/debug.py
@@ -1,25 +0,0 @@
-from collections import deque
-
-import torch
-
-from ...core import Module
-from ...utils.tensorlist import Distributions
-
-class PrintUpdate(Module):
-    def __init__(self, text = 'update = ', print_fn = print):
-        defaults = dict(text=text, print_fn=print_fn)
-        super().__init__(defaults)
-
-    def step(self, vars):
-        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{vars.update}')
-        return vars
-
-class PrintShape(Module):
-    def __init__(self, text = 'shapes = ', print_fn = print):
-        defaults = dict(text=text, print_fn=print_fn)
-        super().__init__(defaults)
-
-    def step(self, vars):
-        shapes = [u.shape for u in vars.update] if vars.update is not None else None
-        self.settings[vars.params[0]]["print_fn"](f'{self.settings[vars.params[0]]["text"]}{shapes}')
-        return vars
torchzero/modules/ops/misc.py
@@ -1,419 +0,0 @@
-from collections import deque
-from collections.abc import Iterable
-from operator import itemgetter
-from typing import Literal
-
-import torch
-
-from ...core import Chainable, Module, TensorwiseTransform, Target, Transform, Vars
-from ...utils import Distributions, NumberList, TensorList
-
-
-class Previous(TensorwiseTransform):
-    """Maintains an update from n steps back, for example if n=1, returns previous update"""
-    def __init__(self, n=1, target: Target = 'update'):
-        defaults = dict(n=n)
-        super().__init__(uses_grad=False, defaults=defaults, target=target)
-
-
-    @torch.no_grad
-    def transform(self, tensor, param, grad, vars):
-        n = self.settings[param]['n']
-        state = self.state[param]
-
-        if 'history' not in state:
-            state['history'] = deque(maxlen=n+1)
-
-        state['history'].append(tensor)
-
-        return state['history'][0]
-
-
-class LastDifference(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params) # initialized to 0
-        difference = torch._foreach_sub(tensors, prev_target)
-        for p, c in zip(prev_target, tensors): p.set_(c)
-        return difference
-
-class LastGradDifference(Module):
-    """Difference between past two grads."""
-    def __init__(self):
-        super().__init__({})
-
-    @torch.no_grad
-    def step(self, vars):
-        grad = vars.get_grad()
-        prev_grad = self.get_state('prev_grad', params=vars.params) # initialized to 0
-        difference = torch._foreach_sub(grad, prev_grad)
-        for p, c in zip(prev_grad, grad): p.set_(c)
-        vars.update = list(difference)
-        return vars
-
-
-class LastProduct(Transform):
-    """Difference between past two updates."""
-    def __init__(self,target: Target = 'update'):
-        super().__init__({}, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init=torch.ones_like) # initialized to 1 for prod
-        prod = torch._foreach_mul(tensors, prev_target)
-        for p, c in zip(prev_target, tensors): p.set_(c)
-        return prod
-
-class LastRatio(Transform):
-    """Ratio between past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
-        defaults = dict(numerator=numerator)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to ones
-        numerator = self.settings[params[0]]['numerator']
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
-        else: ratio = torch._foreach_div(prev_target, tensors)
-        for p, c in zip(prev_target, tensors): p.set_(c)
-        return ratio
-
-class LastAbsoluteRatio(Transform):
-    """Ratio between absolute values of past two updates."""
-    def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
-        defaults = dict(numerator=numerator, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev_target = self.get_state('prev_target', params=params, init = torch.ones_like) # initialized to 0
-        numerator = self.settings[params[0]]['numerator']
-        eps = self.get_settings('eps', params=params, cls = NumberList)
-
-        torch._foreach_abs_(tensors)
-        torch._foreach_clamp_min_(prev_target, eps)
-
-        if numerator == 'cur': ratio = torch._foreach_div(tensors, prev_target)
-        else: ratio = torch._foreach_div(prev_target, tensors)
-        for p, c in zip(prev_target, tensors): p.set_(c)
-        return ratio
-
-class GradSign(Transform):
-    """copy gradient sign to update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        assert grads is not None
-        return [t.copysign_(g) for t,g in zip(tensors, grads)]
-
-class UpdateSign(Transform):
-    """use per-weight magnitudes from grad while using sign from update."""
-    def __init__(self, target: Target = 'update'):
-        super().__init__({}, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        assert grads is not None
-        return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
-
-class GraftToGrad(Transform):
-    """use gradient norm and update direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
-        return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class GraftGradToUpdate(Transform):
-    """use update norm and gradient direction."""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        assert grads is not None
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
-        return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
-
-
-class GraftToParams(Transform):
-    """makes update norm be set to parameter norm, but norm won't go below eps"""
-    def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
-        defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[params[0]])
-        return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
-
-class Relative(Transform):
-    """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
-    def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
-        defaults = dict(min_value=min_value)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        mul = TensorList(params).abs().clamp_(self.get_settings('min_value', params=params))
-        torch._foreach_mul_(tensors, mul)
-        return tensors
-
-class FillLoss(Module):
-    """makes tensors filled with loss value times alpha"""
-    def __init__(self, alpha: float = 1, backward: bool = True):
-        defaults = dict(alpha=alpha, backward=backward)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, vars):
-        alpha = self.get_settings('alpha', params=vars.params)
-        loss = vars.get_loss(backward=self.settings[vars.params[0]]['backward'])
-        vars.update = [torch.full_like(p, loss*a) for p,a in zip(vars.params, alpha)]
-        return vars
-
-class MulByLoss(Transform):
-    """multiplies update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars): #vars used for loss
-        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
-        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_mul_(tensors, mul)
-        return tensors
-
-class DivByLoss(Transform):
-    """divides update by loss times alpha"""
-    def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True, target: Target = 'update'):
-        defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars): #vars used for loss
-        alpha, min_value = self.get_settings('alpha', 'min_value', params=params)
-        loss = vars.get_loss(backward=self.settings[params[0]]['backward'])
-        mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-        torch._foreach_div_(tensors, mul)
-        return tensors
-
-
-
-def _sequential_step(self: Module, vars: Vars, sequential: bool):
-    params = vars.params
-    steps = self.settings[params[0]]['steps']
-
-    if sequential: modules = self.get_children_sequence()
-    else: modules = [self.children['module']] * steps
-
-    if vars.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
-
-    # store original params unless this is last module and can update params directly
-    params_before_steps = None if (vars.is_last and vars.last_module_lrs is None) else [p.clone() for p in params]
-
-    # first step - pass vars as usual
-    vars = modules[0].step(vars)
-    new_vars = vars
-
-    # subsequent steps - update parameters and create new vars
-    if len(modules) > 1:
-        for m in modules[1:]:
-
-            # update params
-            if (not new_vars.skip_update):
-                if new_vars.last_module_lrs is not None:
-                    torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
-
-                torch._foreach_sub_(params, new_vars.get_update())
-
-            # create new vars since we are at a new point, that means grad, update and loss will be None
-            new_vars = Vars(params=new_vars.params, closure=new_vars.closure,
-                            model=new_vars.model, current_step=new_vars.current_step + 1)
-
-            # step
-            new_vars = m.step(new_vars)
-
-    # final parameter update
-    if (not new_vars.skip_update):
-        if new_vars.last_module_lrs is not None:
-            torch._foreach_mul_(new_vars.get_update(), new_vars.last_module_lrs)
-
-        torch._foreach_sub_(params, new_vars.get_update())
-
-    # if last module, update is applied so return new vars
-    if params_before_steps is None:
-        new_vars.stop = True
-        new_vars.skip_update = True
-        return new_vars
-
-    # otherwise use parameter difference as update
-    vars.update = list(torch._foreach_sub(params_before_steps, params))
-    for p, bef in zip(params, params_before_steps):
-        p.set_(bef) # pyright:ignore[reportArgumentType]
-    return vars
-
-class Multistep(Module):
-    def __init__(self, module: Chainable, steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_child('module', module)
-
-    @torch.no_grad
-    def step(self, vars):
-        return _sequential_step(self, vars, sequential=False)
-
-class Sequential(Module):
-    def __init__(self, modules: Iterable[Chainable], steps: int):
-        defaults = dict(steps=steps)
-        super().__init__(defaults)
-        self.set_children_sequence(modules)
-
-    @torch.no_grad
-    def step(self, vars):
-        return _sequential_step(self, vars, sequential=True)
-
-
-class GradientAccumulation(Module):
-    """gradient accumulation"""
-    def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-        defaults = dict(n=n, mean=mean, stop=stop)
-        super().__init__(defaults)
-        self.set_child('modules', modules)
-
-
-    @torch.no_grad
-    def step(self, vars):
-        accumulator = self.get_state('accumulator', params=vars.params)
-        settings = self.settings[vars.params[0]]
-        n = settings['n']; mean = settings['mean']; stop = settings['stop']
-        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-        # add update to accumulator
-        torch._foreach_add_(accumulator, vars.get_update())
-
-        # step with accumulated updates
-        if step % n == 0:
-            if mean:
-                torch._foreach_div_(accumulator, n)
-
-            vars.update = [a.clone() for a in accumulator]
-            vars = self.children['modules'].step(vars)
-
-            # zero accumulator
-            torch._foreach_zero_(accumulator)
-
-        else:
-            # prevent update
-            if stop:
-                vars.stop=True
-                vars.skip_update=True
-
-        return vars
-
-
-class Dropout(Transform):
-    def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults, uses_grad=False, target=target)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        tensors = TensorList(tensors)
-        p = self.get_settings('p', params=params, cls=NumberList)
-        graft = self.settings[params[0]]['graft']
-
-        if graft:
-            target_norm = tensors.global_vector_norm()
-            tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-            return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
-
-        return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-
-class WeightDropout(Module):
-    """Applies dropout directly to weights."""
-    def __init__(self, p: float = 0.5, graft: bool = True):
-        defaults = dict(p=p, graft=graft)
-        super().__init__(defaults)
-
-    @torch.no_grad
-    def step(self, vars):
-        closure = vars.closure
-        if closure is None: raise RuntimeError('WeightDropout requires closure')
-        params = TensorList(vars.params)
-        p = self.get_settings('p', params=params)
-        mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
-
-        @torch.no_grad
-        def dropout_closure(backward=True):
-            orig_params = params.clone()
-            params.mul_(mask)
-            if backward:
-                with torch.enable_grad(): loss = closure()
-            else:
-                loss = closure(False)
-            params.copy_(orig_params)
-            return loss
-
-        vars.closure = dropout_closure
-        return vars
-
-class NoiseSign(Transform):
-    """uses random vector with update sign"""
-    def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-        defaults = dict(distribution=distribution, alpha=alpha)
-        super().__init__(defaults, uses_grad=False)
-
-    @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        alpha = self.get_settings('alpha', params=params)
-        distribution = self.settings[params[0]]['distribution']
-        return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
-
-
-class NegateOnLossIncrease(Module):
-    def __init__(self, backtrack=True):
-        defaults = dict(backtrack=backtrack)
-        super().__init__(defaults=defaults)
-
-    @torch.no_grad
-    def step(self, vars):
-        closure = vars.closure
-        if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-        backtrack = self.settings[vars.params[0]]['backtrack']
-
-        update = vars.get_update()
-        f_0 = vars.get_loss(backward=False)
-
-        torch._foreach_sub_(vars.params, update)
-        f_1 = closure(False)
-
-        if f_1 <= f_0:
-            if vars.is_last and vars.last_module_lrs is None:
-                vars.stop = True
-                vars.skip_update = True
-                return vars
-
-            torch._foreach_add_(vars.params, update)
-            return vars
-
-        torch._foreach_add_(vars.params, update)
-        if backtrack:
-            torch._foreach_neg_(vars.update)
-        else:
-            torch._foreach_zero_(vars.update)
-        return vars
torchzero/modules/ops/split.py
@@ -1,75 +0,0 @@
-from collections.abc import Callable
-from typing import cast
-
-import torch
-
-from ...core import Chainable, Module, Vars
-
-
-def _split(
-    module: Module,
-    idxs,
-    params,
-    vars: Vars,
-):
-    split_params = [p for i,p in enumerate(params) if i in idxs]
-
-    split_grad = None
-    if vars.grad is not None:
-        split_grad = [g for i,g in enumerate(vars.grad) if i in idxs]
-
-    split_update = None
-    if vars.update is not None:
-        split_update = [u for i,u in enumerate(vars.update) if i in idxs]
-
-    split_vars = vars.clone(clone_update=False)
-    split_vars.params = split_params
-    split_vars.grad = split_grad
-    split_vars.update = split_update
-
-    split_vars = module.step(split_vars)
-
-    if (vars.grad is None) and (split_vars.grad is not None):
-        vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
-    if split_vars.update is not None:
-
-        if vars.update is None:
-            if vars.grad is None: vars.update = [cast(torch.Tensor, None) for _ in vars.params]
-            else: vars.update = [g.clone() for g in vars.grad]
-
-        for idx, u in zip(idxs, split_vars.update):
-            vars.update[idx] = u
-
-    vars.update_attrs_from_clone_(split_vars)
-    return vars
-
-class Split(Module):
-    """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
-    def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
-        defaults = dict(filter=filter)
-        super().__init__(defaults)
-
-        if true is not None: self.set_child('true', true)
-        if false is not None: self.set_child('false', false)
-
-    def step(self, vars):
-
-        params = vars.params
-        filter = self.settings[params[0]]['filter']
-
-        true_idxs = []
-        false_idxs = []
-        for i,p in enumerate(params):
-            if filter(p): true_idxs.append(i)
-            else: false_idxs.append(i)
-
-        if 'true' in self.children:
-            true = self.children['true']
-            vars = _split(true, idxs=true_idxs, params=params, vars=vars)
-
-        if 'false' in self.children:
-            false = self.children['false']
-            vars = _split(false, idxs=false_idxs, params=params, vars=vars)
-
-        return vars
torchzero/modules/quasi_newton/experimental/__init__.py
@@ -1 +0,0 @@
-from .modular_lbfgs import ModularLBFGS