torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/ops/debug.py (removed)
@@ -1,25 +0,0 @@
- from collections import deque
-
- import torch
-
- from ...core import Module
- from ...utils.tensorlist import Distributions
-
- class PrintUpdate(Module):
-     def __init__(self, text = 'update = ', print_fn = print):
-         defaults = dict(text=text, print_fn=print_fn)
-         super().__init__(defaults)
-
-     def step(self, var):
-         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.update}')
-         return var
-
- class PrintShape(Module):
-     def __init__(self, text = 'shapes = ', print_fn = print):
-         defaults = dict(text=text, print_fn=print_fn)
-         super().__init__(defaults)
-
-     def step(self, var):
-         shapes = [u.shape for u in var.update] if var.update is not None else None
-         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
-         return var
torchzero/modules/ops/misc.py (removed)
@@ -1,418 +0,0 @@
- from collections import deque
- from collections.abc import Iterable
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, Target, TensorwiseTransform, Transform, Var
- from ...utils import Distributions, NumberList, TensorList, unpack_dicts, unpack_states
-
-
- class Previous(TensorwiseTransform):
-     """Maintains an update from n steps back, for example if n=1, returns previous update"""
-     def __init__(self, n=1, target: Target = 'update'):
-         defaults = dict(n=n)
-         super().__init__(uses_grad=False, defaults=defaults, target=target)
-
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         n = settings['n']
-
-         if 'history' not in state:
-             state['history'] = deque(maxlen=n+1)
-
-         state['history'].append(tensor)
-
-         return state['history'][0]
-
-
- class LastDifference(Transform):
-     """Difference between past two updates."""
-     def __init__(self,target: Target = 'update'):
-         super().__init__({}, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev_target') # initialized to 0
-         difference = torch._foreach_sub(tensors, prev)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return difference
-
- class LastGradDifference(Module):
-     """Difference between past two grads."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def step(self, var):
-         grad = var.get_grad()
-         prev_grad = self.get_state(var.params, 'prev_grad') # initialized to 0
-         difference = torch._foreach_sub(grad, prev_grad)
-         for p, c in zip(prev_grad, grad): p.set_(c)
-         var.update = list(difference)
-         return var
-
-
- class LastProduct(Transform):
-     """Difference between past two updates."""
-     def __init__(self,target: Target = 'update'):
-         super().__init__({}, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init=torch.ones_like) # initialized to 1 for prod
-         prod = torch._foreach_mul(tensors, prev)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return prod
-
- class LastRatio(Transform):
-     """Ratio between past two updates."""
-     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', target: Target = 'update'):
-         defaults = dict(numerator=numerator)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-         numerator = settings[0]['numerator']
-         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-         else: ratio = torch._foreach_div(prev, tensors)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return ratio
-
- class LastAbsoluteRatio(Transform):
-     """Ratio between absolute values of past two updates."""
-     def __init__(self, numerator: Literal['cur', 'prev'] = 'cur', eps:float=1e-8, target: Target = 'update'):
-         defaults = dict(numerator=numerator, eps=eps)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         prev = unpack_states(states, tensors, 'prev', init = torch.ones_like) # initialized to ones
-         numerator = settings[0]['numerator']
-         eps = NumberList(s['eps'] for s in settings)
-
-         torch._foreach_abs_(tensors)
-         torch._foreach_clamp_min_(prev, eps)
-
-         if numerator == 'cur': ratio = torch._foreach_div(tensors, prev)
-         else: ratio = torch._foreach_div(prev, tensors)
-         for p, c in zip(prev, tensors): p.set_(c)
-         return ratio
-
- class GradSign(Transform):
-     """copy gradient sign to update."""
-     def __init__(self, target: Target = 'update'):
-         super().__init__({}, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         return [t.copysign_(g) for t,g in zip(tensors, grads)]
-
- class UpdateSign(Transform):
-     """use per-weight magnitudes from grad while using sign from update."""
-     def __init__(self, target: Target = 'update'):
-         super().__init__({}, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         return [g.copysign(t) for t,g in zip(tensors, grads)] # no in-place
-
- class GraftToGrad(Transform):
-     """use gradient norm and update direction."""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(tensors).graft_(grads, tensorwise=tensorwise, ord=ord, eps=eps)
-
- class GraftGradToUpdate(Transform):
-     """use update norm and gradient direction."""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-6, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         assert grads is not None
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(grads).graft(tensors, tensorwise=tensorwise, ord=ord, eps=eps)
-
-
- class GraftToParams(Transform):
-     """makes update norm be set to parameter norm, but norm won't go below eps"""
-     def __init__(self, tensorwise:bool=False, ord:float=2, eps:float = 1e-4, target: Target = 'update'):
-         defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(settings[0])
-         return TensorList(tensors).graft_(params, tensorwise=tensorwise, ord=ord, eps=eps)
-
- class Relative(Transform):
-     """multiplies update by absolute parameter values to make it relative to their magnitude, min_value is minimum value to avoid getting stuck at 0"""
-     def __init__(self, min_value:float = 1e-4, target: Target = 'update'):
-         defaults = dict(min_value=min_value)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         mul = TensorList(params).abs().clamp_([s['min_value'] for s in settings])
-         torch._foreach_mul_(tensors, mul)
-         return tensors
-
- class FillLoss(Module):
-     """makes tensors filled with loss value times alpha"""
-     def __init__(self, alpha: float = 1, backward: bool = True):
-         defaults = dict(alpha=alpha, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha = self.get_settings(var.params, 'alpha')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         var.update = [torch.full_like(p, loss*a) for p,a in zip(var.params, alpha)]
-         return var
-
- class MulByLoss(Module):
-     """multiplies update by loss times alpha"""
-     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-         torch._foreach_mul_(var.update, mul)
-         return var
-
- class DivByLoss(Module):
-     """divides update by loss times alpha"""
-     def __init__(self, alpha: float = 1, min_value:float = 1e-8, backward: bool = True):
-         defaults = dict(alpha=alpha, min_value=min_value, backward=backward)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         alpha, min_value = self.get_settings(var.params, 'alpha', 'min_value')
-         loss = var.get_loss(backward=self.settings[var.params[0]]['backward'])
-         mul = [max(loss*a, mv) for a,mv in zip(alpha, min_value)]
-         torch._foreach_div_(var.update, mul)
-         return var
-
-
-
- def _sequential_step(self: Module, var: Var, sequential: bool):
-     params = var.params
-     steps = self.settings[params[0]]['steps']
-
-     if sequential: modules = self.get_children_sequence()
-     else: modules = [self.children['module']] * steps
-
-     if var.closure is None and len(modules) > 1: raise ValueError('Multistep and Sequential require closure')
-
-     # store original params unless this is last module and can update params directly
-     params_before_steps = None if (var.is_last and var.last_module_lrs is None) else [p.clone() for p in params]
-
-     # first step - pass var as usual
-     var = modules[0].step(var)
-     new_var = var
-
-     # subsequent steps - update parameters and create new var
-     if len(modules) > 1:
-         for m in modules[1:]:
-
-             # update params
-             if (not new_var.skip_update):
-                 if new_var.last_module_lrs is not None:
-                     torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-                 torch._foreach_sub_(params, new_var.get_update())
-
-             # create new var since we are at a new point, that means grad, update and loss will be None
-             new_var = Var(params=new_var.params, closure=new_var.closure,
-                           model=new_var.model, current_step=new_var.current_step + 1)
-
-             # step
-             new_var = m.step(new_var)
-
-     # final parameter update
-     if (not new_var.skip_update):
-         if new_var.last_module_lrs is not None:
-             torch._foreach_mul_(new_var.get_update(), new_var.last_module_lrs)
-
-         torch._foreach_sub_(params, new_var.get_update())
-
-     # if last module, update is applied so return new var
-     if params_before_steps is None:
-         new_var.stop = True
-         new_var.skip_update = True
-         return new_var
-
-     # otherwise use parameter difference as update
-     var.update = list(torch._foreach_sub(params_before_steps, params))
-     for p, bef in zip(params, params_before_steps):
-         p.set_(bef) # pyright:ignore[reportArgumentType]
-     return var
-
- class Multistep(Module):
-     def __init__(self, module: Chainable, steps: int):
-         defaults = dict(steps=steps)
-         super().__init__(defaults)
-         self.set_child('module', module)
-
-     @torch.no_grad
-     def step(self, var):
-         return _sequential_step(self, var, sequential=False)
-
- class Sequential(Module):
-     def __init__(self, modules: Iterable[Chainable], steps: int):
-         defaults = dict(steps=steps)
-         super().__init__(defaults)
-         self.set_children_sequence(modules)
-
-     @torch.no_grad
-     def step(self, var):
-         return _sequential_step(self, var, sequential=True)
-
-
- class GradientAccumulation(Module):
-     """gradient accumulation"""
-     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
-         defaults = dict(n=n, mean=mean, stop=stop)
-         super().__init__(defaults)
-         self.set_child('modules', modules)
-
-
-     @torch.no_grad
-     def step(self, var):
-         accumulator = self.get_state(var.params, 'accumulator')
-         settings = self.settings[var.params[0]]
-         n = settings['n']; mean = settings['mean']; stop = settings['stop']
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         # add update to accumulator
-         torch._foreach_add_(accumulator, var.get_update())
-
-         # step with accumulated updates
-         if step % n == 0:
-             if mean:
-                 torch._foreach_div_(accumulator, n)
-
-             var.update = [a.clone() for a in accumulator]
-             var = self.children['modules'].step(var)
-
-             # zero accumulator
-             torch._foreach_zero_(accumulator)
-
-         else:
-             # prevent update
-             if stop:
-                 var.stop=True
-                 var.skip_update=True
-
-         return var
-
-
- class Dropout(Transform):
-     def __init__(self, p: float = 0.5, graft: bool=False, target: Target = 'update'):
-         defaults = dict(p=p, graft=graft)
-         super().__init__(defaults, uses_grad=False, target=target)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         tensors = TensorList(tensors)
-         p = NumberList(s['p'] for s in settings)
-         graft = settings[0]['graft']
-
-         if graft:
-             target_norm = tensors.global_vector_norm()
-             tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-             return tensors.mul_(target_norm / tensors.global_vector_norm()) # graft
-
-         return tensors.mul_(tensors.rademacher_like(1-p).add_(1).div_(2))
-
- class WeightDropout(Module):
-     """Applies dropout directly to weights."""
-     def __init__(self, p: float = 0.5, graft: bool = True):
-         defaults = dict(p=p, graft=graft)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         if closure is None: raise RuntimeError('WeightDropout requires closure')
-         params = TensorList(var.params)
-         p = NumberList(self.settings[p]['p'] for p in params)
-         mask = params.rademacher_like(p).add_(1).div_(2).as_bool()
-
-         @torch.no_grad
-         def dropout_closure(backward=True):
-             orig_params = params.clone()
-             params.mul_(mask)
-             if backward:
-                 with torch.enable_grad(): loss = closure()
-             else:
-                 loss = closure(False)
-             params.copy_(orig_params)
-             return loss
-
-         var.closure = dropout_closure
-         return var
-
- class NoiseSign(Transform):
-     """uses random vector with update sign"""
-     def __init__(self, distribution:Distributions = 'normal', alpha = 1):
-         defaults = dict(distribution=distribution, alpha=alpha)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         alpha = [s['alpha'] for s in settings]
-         distribution = self.settings[params[0]]['distribution']
-         return TensorList(tensors).sample_like(alpha, distribution).copysign_(tensors)
-
-
- class NegateOnLossIncrease(Module):
-     def __init__(self, backtrack=True):
-         defaults = dict(backtrack=backtrack)
-         super().__init__(defaults=defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         if closure is None: raise RuntimeError('NegateOnLossIncrease requires closure')
-         backtrack = self.settings[var.params[0]]['backtrack']
-
-         update = var.get_update()
-         f_0 = var.get_loss(backward=False)
-
-         torch._foreach_sub_(var.params, update)
-         f_1 = closure(False)
-
-         if f_1 <= f_0:
-             if var.is_last and var.last_module_lrs is None:
-                 var.stop = True
-                 var.skip_update = True
-                 return var
-
-             torch._foreach_add_(var.params, update)
-             return var
-
-         torch._foreach_add_(var.params, update)
-         if backtrack:
-             torch._foreach_neg_(var.update)
-         else:
-             torch._foreach_zero_(var.update)
-         return var
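The deleted GradientAccumulation module above (apparently superseded by torchzero/modules/misc/gradient_accumulation.py in 0.3.13, per the file list) sums updates into an accumulator and only lets every n-th step through, optionally averaged. A standalone toy sketch of that arithmetic, independent of the torchzero API:

import torch

n, mean = 4, True
accumulator = torch.zeros(3)
for step in range(1, 9):
    update = torch.ones(3)              # stand-in for the update produced this step
    accumulator += update               # accumulate
    if step % n == 0:                   # only every n-th step is applied
        applied = accumulator / n if mean else accumulator.clone()
        accumulator.zero_()
        print(step, applied)            # here only steps 4 and 8 would change the parameters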
torchzero/modules/ops/split.py (removed)
@@ -1,75 +0,0 @@
- from collections.abc import Callable
- from typing import cast
-
- import torch
-
- from ...core import Chainable, Module, Var
-
-
- def _split(
-     module: Module,
-     idxs,
-     params,
-     var: Var,
- ):
-     split_params = [p for i,p in enumerate(params) if i in idxs]
-
-     split_grad = None
-     if var.grad is not None:
-         split_grad = [g for i,g in enumerate(var.grad) if i in idxs]
-
-     split_update = None
-     if var.update is not None:
-         split_update = [u for i,u in enumerate(var.update) if i in idxs]
-
-     split_var = var.clone(clone_update=False)
-     split_var.params = split_params
-     split_var.grad = split_grad
-     split_var.update = split_update
-
-     split_var = module.step(split_var)
-
-     if (var.grad is None) and (split_var.grad is not None):
-         var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
-
-     if split_var.update is not None:
-
-         if var.update is None:
-             if var.grad is None: var.update = [cast(torch.Tensor, None) for _ in var.params]
-             else: var.update = [g.clone() for g in var.grad]
-
-         for idx, u in zip(idxs, split_var.update):
-             var.update[idx] = u
-
-     var.update_attrs_from_clone_(split_var)
-     return var
-
- class Split(Module):
-     """Apply `true` modules to all parameters filtered by `filter`, apply `false` modules to all other parameters."""
-     def __init__(self, filter: Callable[[torch.Tensor], bool], true: Chainable | None, false: Chainable | None):
-         defaults = dict(filter=filter)
-         super().__init__(defaults)
-
-         if true is not None: self.set_child('true', true)
-         if false is not None: self.set_child('false', false)
-
-     def step(self, var):
-
-         params = var.params
-         filter = self.settings[params[0]]['filter']
-
-         true_idxs = []
-         false_idxs = []
-         for i,p in enumerate(params):
-             if filter(p): true_idxs.append(i)
-             else: false_idxs.append(i)
-
-         if 'true' in self.children:
-             true = self.children['true']
-             var = _split(true, idxs=true_idxs, params=params, var=var)
-
-         if 'false' in self.children:
-             false = self.children['false']
-             var = _split(false, idxs=false_idxs, params=params, var=var)
-
-         return var
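The deleted Split module above routes parameters to one of two sub-chains based on a user-supplied predicate (the same functionality appears to move to torchzero/modules/misc/split.py, per the file list). A minimal sketch of just the routing rule, using plain PyTorch and illustrative names rather than the torchzero API:

import torch

params = [torch.randn(32, 16), torch.randn(16), torch.randn(8, 8)]
filter_fn = lambda p: p.ndim >= 2        # e.g. send matrices one way, vectors the other

true_idxs  = [i for i, p in enumerate(params) if filter_fn(p)]      # -> [0, 2]
false_idxs = [i for i, p in enumerate(params) if not filter_fn(p)]  # -> [1]
# Split.step applies the 'true' child chain to the first group and the
# 'false' child chain to the second, then merges the resulting updates back.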
torchzero/modules/optimizers/__init__.py (removed)
@@ -1,18 +0,0 @@
- from .adagrad import Adagrad, FullMatrixAdagrad
- from .adam import Adam
- from .lion import Lion
- from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
- from .rmsprop import RMSprop
- from .rprop import (
-     BacktrackOnSignChange,
-     Rprop,
-     ScaleLRBySignChange,
-     SignConsistencyLRs,
-     SignConsistencyMask,
- )
- from .shampoo import Shampoo
- from .soap import SOAP
- from .orthograd import OrthoGrad, orthograd_
- from .sophia_h import SophiaH
- # from .curveball import CurveBall
- # from .spectral import SpectralPreconditioner
torchzero/modules/optimizers/adagrad.py (removed)
@@ -1,155 +0,0 @@
- from operator import itemgetter
- from typing import Literal
-
- import torch
- from ...core import (
-     Chainable,
-     Module,
-     Target,
-     TensorwiseTransform,
-     Transform,
-     Var,
-     apply_transform,
- )
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ...utils.linalg import matrix_power_eigh
- from ..functional import add_power_, lerp_power_, root
-
-
- def adagrad_(
-     tensors_: TensorList,
-     sq_sum_: TensorList,
-     alpha: float | NumberList,
-     lr_decay: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     use_sqrt: bool = True,
-
-     # inner args
-     inner: Module | None = None,
-     params: list[torch.Tensor] | None = None,
-     grads: list[torch.Tensor] | None = None,
- ):
-     """returns `tensors_`"""
-     clr = alpha / (1 + step * lr_decay)
-
-     sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-
-     if inner is not None:
-         assert params is not None
-         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
-
-     return tensors_
-
-
-
- class Adagrad(Transform):
-     """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
-
-     Args:
-         lr_decay (float, optional): learning rate decay. Defaults to 0.
-         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-         eps (float, optional): division epsilon. Defaults to 1e-10.
-         alpha (float, optional): step size. Defaults to 1.
-         pow (float, optional): power for gradients and accumulator root. Defaults to 2.
-         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
-         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
-     """
-     def __init__(
-         self,
-         lr_decay: float = 0,
-         initial_accumulator_value: float = 0,
-         eps: float = 1e-10,
-         alpha: float = 1,
-         pow: float = 2,
-         use_sqrt: bool = True,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                         eps = eps, pow=pow, use_sqrt = use_sqrt)
-         super().__init__(defaults=defaults, uses_grad=False)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         tensors = TensorList(tensors)
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
-
-         pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
-
-         sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-         # initialize accumulator on 1st step
-         if step == 1:
-             sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-         return adagrad_(
-             tensors,
-             sq_sum_=sq_sum,
-             alpha=alpha,
-             lr_decay=lr_decay,
-             eps=eps,
-             step=self.global_state["step"],
-             pow=pow,
-             use_sqrt=use_sqrt,
-
-             # inner args
-             inner=self.children.get("inner", None),
-             params=params,
-             grads=grads,
-         )
-
-
-
- class FullMatrixAdagrad(TensorwiseTransform):
-     def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=False, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', inner: Chainable | None = None):
-         defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
-         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
-
-     @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, settings):
-         G = tensor.ravel()
-         GG = torch.outer(G, G)
-         decay = settings['decay']
-         beta = settings['beta']
-         init = settings['init']
-
-         if 'GG' not in state:
-             if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
-             elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-             elif init == 'ones': state['GG'] = torch.ones_like(GG)
-             elif init == 'GGT': state['GG'] = GG.clone()
-             else: raise ValueError(init)
-         if decay is not None: state['GG'].mul_(decay)
-
-         if beta is not None: state['GG'].lerp_(GG, 1-beta)
-         else: state['GG'].add_(GG)
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
-         GG = state['GG']
-         sqrt = settings['sqrt']
-
-         if tensor.numel() == 1:
-             GG = GG.squeeze()
-             if sqrt: return tensor / GG.sqrt()
-             return tensor / GG
-
-         try:
-             if sqrt: B = matrix_power_eigh(GG, -1/2)
-             else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
-
-         except torch.linalg.LinAlgError:
-             scale = 1 / tensor.abs().max()
-             return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
-
-         return (B @ tensor.ravel()).view_as(tensor)
-
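For reference, the update computed by the deleted adagrad_ helper (re-implemented in torchzero/modules/adaptive/adagrad.py in 0.3.13, per the file list) reduces, for a single tensor, to the following sketch; the function name and the single-tensor simplification are illustrative, not the torchzero API:

import torch

def adagrad_update(grad, sq_sum, step, alpha=1.0, lr_decay=0.0, eps=1e-10, pow=2, use_sqrt=True):
    clr = alpha / (1 + step * lr_decay)          # decayed step size
    sq_sum += grad.abs() ** pow                  # accumulate |g|^pow in place
    denom = sq_sum ** (1 / pow) if use_sqrt else sq_sum
    return grad / (denom + eps) * clr            # preconditioned update

g = torch.tensor([0.5, -2.0])
acc = torch.zeros_like(g)
print(adagrad_update(g, acc, step=1))            # ≈ g / (|g| + eps) on the first step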