torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
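Much of this release is a reorganization: the modules under torchzero/modules/optimizers/ now live under torchzero/modules/adaptive/, cg.py moves from quasi_newton/ to a new conjugate_gradient/ package, and a batch of experimental modules is deleted outright (their full contents appear in the hunks below). The snippet below is a hypothetical compatibility shim, not part of torchzero's API: it only assumes that deep import paths mirror the directory layout shown above, and it uses rmsprop.py (entry 163, moved with no content changes) as the example.

import importlib

def load_rmsprop_module():
    # Try the 0.3.14 layout first, then fall back to the 0.3.11 layout.
    # Whether these deep imports are part of torchzero's supported public API
    # is an assumption; only the directory move itself is shown in this diff.
    for path in ("torchzero.modules.adaptive.rmsprop",    # new location (0.3.14)
                 "torchzero.modules.optimizers.rmsprop"): # old location (0.3.11)
        try:
            return importlib.import_module(path)
        except ModuleNotFoundError:
            continue
    raise ImportError("rmsprop module not found under either layout")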
torchzero/modules/experimental/modular_lbfgs.py (deleted)
@@ -1,265 +0,0 @@
- from collections import deque
- from operator import itemgetter
- from typing import Any
-
- import torch
-
- from ...core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
- from ...utils import NumberList, TensorList, as_tensorlist
-
-
- def _adaptive_damping(
-     s_k: TensorList,
-     y_k: TensorList,
-     ys_k: torch.Tensor,
-     init_damping = 0.99,
-     eigval_bounds = (0.01, 1.5)
- ):
-     # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
-     sigma_l, sigma_h = eigval_bounds
-     u = ys_k / s_k.dot(s_k)
-     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
-     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
-     else: tau = init_damping
-     y_k = tau * y_k + (1-tau) * s_k
-     ys_k = s_k.dot(y_k)
-
-     return s_k, y_k, ys_k
-
- def lbfgs(
-     tensors_: TensorList,
-     var: Var,
-     s_history: deque[TensorList],
-     y_history: deque[TensorList],
-     sy_history: deque[torch.Tensor],
-     y_k: TensorList | None,
-     ys_k: torch.Tensor | None,
-     z_tfm: Any,
- ):
-     if len(s_history) == 0 or y_k is None or ys_k is None:
-
-         # initial step size guess modified from pytorch L-BFGS
-         scale = 1 / tensors_.abs().global_sum()
-         if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-         return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
-
-     # 1st loop
-     alpha_list = []
-     q = tensors_.clone()
-     for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-         p_i = 1 / ys_i # this is also denoted as ρ (rho)
-         alpha = p_i * s_i.dot(q)
-         alpha_list.append(alpha)
-         q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-     # calculate z
-     # s.y/y.y is also this weird y-looking symbol I couldn't find
-     # z is it times q
-     # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-     z = q * (ys_k / (y_k.dot(y_k)))
-
-     if z_tfm is not None:
-         z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))
-
-     # 2nd loop
-     for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-         p_i = 1 / ys_i
-         beta_i = p_i * y_i.dot(z)
-         z.add_(s_i, alpha = alpha_i - beta_i)
-
-     return z
-
- def _apply_tfms_into_history(
-     self: Module,
-     params: list[torch.Tensor],
-     var: Var,
-     update: list[torch.Tensor],
- ):
-     if 'params_history_tfm' in self.children:
-         params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-     if 'grad_history_tfm' in self.children:
-         update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)
-
-     return params, update
-
- def _apply_tfms_into_precond(
-     self: Module,
-     params: list[torch.Tensor],
-     var: Var,
-     update: list[torch.Tensor],
- ):
-     if 'params_precond_tfm' in self.children:
-         params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)
-
-     if 'grad_precond_tfm' in self.children:
-         update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)
-
-     return params, update
-
-
- class ModularLBFGS(Module):
-     """L-BFGS with ability to apply transforms to many inner variables.
-
-     Args:
-         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-         tol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-         damping (bool, optional):
-             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-         init_damping (float, optional):
-             initial damping for adaptive dampening. Defaults to 0.9.
-         eigval_bounds (tuple, optional):
-             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-         update_freq (int, optional):
-             how often to update L-BFGS history. Defaults to 1.
-         z_tfm (float | None, optional):
-             transform module applied to initial H^-1 @ q guess. Defaults to None.
-         params_history_tfm (AnyTransform | None, optional):
-             transform module applied to params before adding s_k to history. Defaults to None.
-         grad_history_tfm (AnyTransform | None, optional):
-             transform module applied to grads before adding y_k to history. Defaults to None.
-         params_precond_tfm (AnyTransform | None, optional):
-             transform module applied to params to calculate s_k before preconditioning. Defaults to None.
-         grad_precond_tfm (AnyTransform | None, optional):
-             transform module applied to grads to calculate y_k before preconditioning. Defaults to None.
-         update_precond_tfm (Chainable | None, optional):
-             transform module applied to grads that are being preconditioned. Defaults to None.
-     """
-     def __init__(
-         self,
-         history_size=10,
-         tol: float | None = 1e-10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         update_freq = 1,
-         params_history_tfm: Chainable | None = None,
-         grad_history_tfm: Chainable | None = None,
-         params_precond_tfm: Chainable | None = None,
-         grad_precond_tfm: Chainable | None = None,
-         update_precond_tfm: Chainable | None = None,
-         z_tfm: Chainable | None = None,
-     ):
-         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, update_freq=update_freq)
-         super().__init__(defaults)
-
-         self.global_state['s_history'] = deque(maxlen=history_size)
-         self.global_state['y_history'] = deque(maxlen=history_size)
-         self.global_state['sy_history'] = deque(maxlen=history_size)
-
-         loc = locals().copy()
-         for k in ('update_precond_tfm', 'params_history_tfm', 'grad_history_tfm', 'params_precond_tfm', 'grad_precond_tfm','z_tfm'):
-             v = loc[k]
-             if v is not None:
-                 self.set_child(k,v)
-
-     def reset(self):
-         """Resets the internal state of the L-SR1 module."""
-         # super().reset() # Clears self.state (per-parameter) if any, and "step"
-         self.state.clear()
-         self.global_state['step'] = 0
-         self.global_state['s_history'].clear()
-         self.global_state['y_history'].clear()
-         self.global_state['sy_history'].clear()
-
-     @torch.no_grad
-     def step(self, var):
-         params = as_tensorlist(var.params)
-         update = as_tensorlist(var.get_update())
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         # history of s and k
-         s_history: deque[TensorList] = self.global_state['s_history']
-         y_history: deque[TensorList] = self.global_state['y_history']
-         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-         tol, damping, init_damping, eigval_bounds, update_freq = itemgetter(
-             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq')(self.settings[params[0]])
-
-         # params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params, cls=NumberList)
-         # l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-
-         # params and update that go into history
-         params_h, update_h = _apply_tfms_into_history(
-             self,
-             params=params,
-             var=var,
-             update=update,
-         )
-
-         prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)
-
-         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
-         if step == 0:
-             s_k_h = None; y_k_h = None; ys_k_h = None
-         else:
-             s_k_h = params_h - prev_params_h
-             y_k_h = update_h - prev_grad_h
-             ys_k_h = s_k_h.dot(y_k_h)
-
-             if damping:
-                 s_k_h, y_k_h, ys_k_h = _adaptive_damping(s_k_h, y_k_h, ys_k_h, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-         prev_params_h.copy_(params_h)
-         prev_grad_h.copy_(update_h)
-
-         # update effective preconditioning state
-         if step % update_freq == 0:
-             if ys_k_h is not None and ys_k_h > 1e-10:
-                 assert s_k_h is not None and y_k_h is not None
-                 s_history.append(s_k_h)
-                 y_history.append(y_k_h)
-                 sy_history.append(ys_k_h)
-
-         # step with inner module before applying preconditioner
-         if 'update_precond_tfm' in self.children:
-             update_precond_tfm = self.children['update_precond_tfm']
-             inner_var = update_precond_tfm.step(var.clone(clone_update=True))
-             var.update_attrs_from_clone_(inner_var)
-             tensors = inner_var.update
-             assert tensors is not None
-         else:
-             tensors = update.clone()
-
-         # transforms into preconditioner
-         params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
-         prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)
-
-         if step == 0:
-             s_k_p = None; y_k_p = None; ys_k_p = None
-
-         else:
-             s_k_p = params_p - prev_params_p
-             y_k_p = update_p - prev_grad_p
-             ys_k_p = s_k_p.dot(y_k_p)
-
-             if damping:
-                 s_k_p, y_k_p, ys_k_p = _adaptive_damping(s_k_p, y_k_p, ys_k_p, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-         prev_params_p.copy_(params_p)
-         prev_grad_p.copy_(update_p)
-
-         # tolerance on gradient difference to avoid exploding after converging
-         if tol is not None:
-             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 return var
-
-         # precondition
-         dir = lbfgs(
-             tensors_=as_tensorlist(tensors),
-             var=var,
-             s_history=s_history,
-             y_history=y_history,
-             sy_history=sy_history,
-             y_k=y_k_p,
-             ys_k=ys_k_p,
-             z_tfm=self.children.get('z_tfm', None),
-         )
-
-         var.update = dir
-
-         return var
-
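The file deleted above is torchzero/modules/experimental/modular_lbfgs.py (entry 149 in the file list); its core is the standard L-BFGS two-loop recursion that the `lbfgs` function implements over TensorList histories. As an aid for reading that deleted code, here is a minimal standalone sketch of the same recursion on flat torch tensors; it is illustrative only and does not use the torchzero API.

import torch

def two_loop_recursion(g, s_hist, y_hist, sy_hist):
    # Standalone sketch of the two-loop recursion implemented by the deleted
    # `lbfgs` function, using flat tensors: s_hist holds parameter differences,
    # y_hist gradient differences, sy_hist the corresponding s.y products.
    if len(s_hist) == 0:
        return g.clone()
    q = g.clone()
    alphas = []
    for s, y, sy in zip(reversed(s_hist), reversed(y_hist), reversed(sy_hist)):
        alpha = torch.dot(s, q) / sy        # rho_i * (s_i . q)
        alphas.append(alpha)
        q -= alpha * y
    # initial Hessian approximation H0 = (s.y / y.y) * I applied to q
    z = q * (sy_hist[-1] / torch.dot(y_hist[-1], y_hist[-1]))
    for s, y, sy, alpha in zip(s_hist, y_hist, sy_hist, reversed(alphas)):
        beta = torch.dot(y, z) / sy
        z += (alpha - beta) * s
    return z                                # approximates H^{-1} @ g

The `_adaptive_damping` helper around it is the Al-Baali damping cited in the code comment: it blends y_k toward s_k so that the implied curvature ratio stays within eigval_bounds.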
torchzero/modules/experimental/parabolic_search.py (deleted)
@@ -1,220 +0,0 @@
- import math
- from collections.abc import Mapping
- from operator import itemgetter
-
- import torch
-
- from ...core import Module
- from ...utils import TensorList
-
-
-
- def adaptive_tracking(
-     f,
-     f_0,
-     f_1,
-     t_0,
-     maxiter: int
- ):
-
-     t = t_0
-     f_t = f(t)
-
-     # backtrack
-     if f_t > f_0:
-         if f_1 > f_0: t = min(0.5, t_0/2)
-         while f_t > f_0:
-             maxiter -= 1
-             if maxiter < 0: return 0, f_0
-             t = t/2
-             f_t = f(t) if t!=1 else f_1
-         return t, f_t
-
-     # forwardtrack
-     f_prev = f_t
-     t *= 2
-     f_t = f(t)
-     if f_prev < f_t: return t/2, f_prev
-     while f_prev >= f_t:
-         maxiter -= 1
-         if maxiter < 0: return t, f_t
-         f_prev = f_t
-         t *= 2
-         f_t = f(t)
-     return t/2, f_prev
-
-
-
- class ParabolaSearch(Module):
-     """"""
-     def __init__(
-         self,
-         step_size: float = 1e-2,
-         adaptive: bool=True,
-         normalize: bool=False,
-         # method: str | None = None,
-         maxiter: int | None = 10,
-         # bracket=None,
-         # bounds=None,
-         # tol: float | None = None,
-         # options=None,
-     ):
-         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-         defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
-         super().__init__(defaults)
-
-         import scipy.optimize
-         self.scopt = scipy.optimize
-
-
-     @torch.no_grad
-     def step(self, var):
-         x_0 = TensorList(var.params)
-         closure = var.closure
-         assert closure is not None
-         settings = self.settings[x_0[0]]
-         step_size = settings['step_size']
-         adaptive = settings['adaptive']
-         normalize = settings['normalize']
-         maxiter = settings['maxiter']
-         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-
-         grad = TensorList(var.get_grad())
-         f_0 = var.get_loss(False)
-
-         scale = 1
-         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-         if adaptive: scale = grad.abs().mean().clip(min=1e-8)
-
-         # make step
-         v_0 = grad * (step_size/scale)
-         x_0 -= v_0
-         with torch.enable_grad():
-             f_1 = closure()
-             grad = x_0.grad
-
-         x_0 += v_0
-         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-         v_1 = grad * (step_size/scale)
-         a = v_1 - v_0
-
-         def parabolic_objective(t: float):
-             nonlocal x_0
-
-             step = v_0*t + 0.5*a*t**2
-             x_0 -= step
-             value = closure(False)
-             x_0 += step
-             return value.detach().cpu()
-
-         prev_t = self.global_state.get('prev_t', 2)
-         t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
-         self.global_state['prev_t'] = t
-
-         # method, bracket, bounds, tol, options, maxiter = itemgetter(
-         #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
-
-         # if maxiter is not None:
-         #     options = dict(options) if isinstance(options, Mapping) else {}
-         #     options['maxiter'] = maxiter
-
-         # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-         # t = res.x
-
-         var.update = v_0*t + 0.5*a*t**2
-         return var
-
- class CubicParabolaSearch(Module):
-     """"""
-     def __init__(
-         self,
-         step_size: float = 1e-2,
-         adaptive: bool=True,
-         normalize: bool=False,
-         # method: str | None = None,
-         maxiter: int | None = 10,
-         # bracket=None,
-         # bounds=None,
-         # tol: float | None = None,
-         # options=None,
-     ):
-         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-         defaults = dict(step_size=step_size, adaptive=adaptive, normalize=normalize, maxiter=maxiter)
-         super().__init__(defaults)
-
-         import scipy.optimize
-         self.scopt = scipy.optimize
-
-
-     @torch.no_grad
-     def step(self, var):
-         x_0 = TensorList(var.params)
-         closure = var.closure
-         assert closure is not None
-         settings = self.settings[x_0[0]]
-         step_size = settings['step_size']
-         adaptive = settings['adaptive']
-         maxiter = settings['maxiter']
-         normalize = settings['normalize']
-         if normalize and adaptive: raise ValueError("pick either normalize or adaptive")
-
-         grad = TensorList(var.get_grad())
-         f_0 = var.get_loss(False)
-
-         scale = 1
-         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-         if adaptive: scale = grad.abs().mean().clip(min=1e-8)
-
-         # make step
-         v_0 = grad * (step_size/scale)
-         x_0 -= v_0
-         with torch.enable_grad():
-             f_1 = closure()
-             grad = x_0.grad
-
-         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-         v_1 = grad * (step_size/scale)
-         a_0 = v_1 - v_0
-
-         # make another step
-         x_0 -= v_1
-         with torch.enable_grad():
-             f_2 = closure()
-             grad = x_0.grad
-
-         if normalize: grad = grad/grad.abs().mean().clip(min=1e-8)
-         v_2 = grad * (step_size/scale)
-         a_1 = v_2 - v_1
-
-         j = a_1 - a_0
-
-         x_0 += v_0
-         x_0 += v_1
-
-         def parabolic_objective(t: float):
-             nonlocal x_0
-
-             step = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
-             x_0 -= step
-             value = closure(False)
-             x_0 += step
-             return value
-
-
-         prev_t = self.global_state.get('prev_t', 2)
-         t, f = adaptive_tracking(parabolic_objective, f_0=f_0, f_1=f_1, t_0=prev_t, maxiter=maxiter)
-         self.global_state['prev_t'] = t
-
-         # method, bracket, bounds, tol, options, maxiter = itemgetter(
-         #     'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
-
-         # if maxiter is not None:
-         #     options = dict(options) if isinstance(options, Mapping) else {}
-         #     options['maxiter'] = maxiter
-
-         # res = self.scopt.minimize_scalar(parabolic_objective, method=method, bracket=bracket, bounds=bounds, tol=tol, options=options)
-         # t = res.x
-
-         var.update = v_0*t + (1/2)*a_0*t**2 + (1/6)*j*t**3
-         return var
-
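The hunk above removes torchzero/modules/experimental/parabolic_search.py (entry 150). Both deleted classes fit a parabola (or cubic) through extra gradient evaluations and then pick the step multiplier t with `adaptive_tracking`, which halves t while the trial loss is worse than the starting loss and doubles it while the loss keeps improving. A simplified, self-contained restatement of that tracking loop (dropping the f_1 shortcut and the torchzero plumbing) is sketched below, with a toy quadratic to show what it returns.

def track_step(phi, f0, t=2.0, maxiter=10):
    # phi: any 1-D objective over the step multiplier t; f0 = phi(0).
    ft = phi(t)
    if ft > f0:                      # backtrack: halve until not worse than f0
        while ft > f0:
            maxiter -= 1
            if maxiter < 0:
                return 0.0, f0
            t /= 2
            ft = phi(t)
        return t, ft
    f_prev = ft                      # forward-track: double while improving
    t *= 2
    ft = phi(t)
    while f_prev >= ft:
        maxiter -= 1
        if maxiter < 0:
            return t, ft
        f_prev = ft
        t *= 2
        ft = phi(t)
    return t / 2, f_prev

# toy objective: phi(t) = (t - 3)^2, f0 = phi(0) = 9
t_star, f_star = track_step(lambda t: (t - 3.0) ** 2, f0=9.0)   # -> t_star == 4.0, f_star == 1.0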
torchzero/modules/experimental/subspace_preconditioners.py (deleted)
@@ -1,145 +0,0 @@
- from collections import deque
-
- import torch
- # import visualbench as vb
-
- # import torchzero as tz
-
- from ...core import Transform, Chainable, apply_transform
- from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
- from ...utils import TensorList, vec_to_tensors_
-
-
- def inverse_sqrt(M):
-     if M.shape[-1] == 2: return inv_sqrt_2x2(M, force_pd=True) # general formula for 2x2 matrices
-     return matrix_power_eigh(M, -1/2)
-
- def update_subspace_preconditioner_(
-     grad: torch.Tensor, # store grads and basis as vectors for matmul
-     basis: torch.Tensor, # ndim, k
-     accumulator_: torch.Tensor, # k, k
-     beta: float | None,
- ):
-     projected = basis.T @ grad # k
-     outer = torch.outer(projected, projected)
-
-     if beta is None: accumulator_.add_(outer)
-     else: accumulator_.lerp_(outer, 1-beta)
-
- def apply_subspace_preconditioner(
-     tensor: torch.Tensor,
-     basis: torch.Tensor, # ndim, k
-     accumulator: torch.Tensor,
- ):
-     preconditioner = inverse_sqrt(accumulator) # k,k
-
-     tensor_projected = basis.T @ tensor # k
-     update_projected = preconditioner @ tensor_projected # k
-     return basis @ update_projected # d
-
- class RandomSubspacePreconditioning(Transform):
-     """Whitens in random slowly changing subspace.
-
-     .. warning::
-         Experimental and this is a barebones implementation.
-
-     """
-     def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
-         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         settings = settings[0]
-         g = torch.cat([t.view(-1) for t in tensors])
-         k = settings['k']
-         beta = settings['beta']
-         basis_beta = settings['basis_beta']
-
-         if 'basis' not in self.global_state:
-             self.global_state['basis'] = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-
-         basis = self.global_state['basis']
-         accumulator = self.global_state['accumulator']
-
-         if basis_beta is not None:
-             basis.lerp_(torch.randn_like(basis), 1-basis_beta)
-
-         update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-         if 'inner' in self.children:
-             tensors = apply_transform(self.children['inner'], tensors, params, grads)
-             g = torch.cat([t.view(-1) for t in tensors])
-
-         try:
-             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-         except torch.linalg.LinAlgError:
-             preconditioned = g.clip(-0.1, 0.1)
-         vec_to_tensors_(preconditioned, tensors)
-
-         return tensors
-
-
- class HistorySubspacePreconditioning(Transform):
-     """Whitens in subspace spanned by history of gradient differences.
-
-     .. warning::
-         Experimental and this is a barebones implementation.
-
-     Args:
-         beta - for preconditioner itself in the basis.
-         basis_beta - how much basis is allowed to change.
-     """
-     def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
-         defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         settings = settings[0]
-
-         g = torch.cat([t.view(-1) for t in tensors])
-         k = settings['k']
-         beta = settings['beta']
-         basis_beta = settings['basis_beta']
-
-         if 'history' not in self.global_state:
-             self.global_state['history'] = deque(maxlen=k)
-             self.global_state['accumulator'] = torch.eye(k, device=g.device, dtype=g.dtype)
-             self.global_state['basis'] = torch.ones(g.numel(), k, device=g.device, dtype=g.dtype)
-
-
-         history: deque = self.global_state['history']
-         accumulator = self.global_state['accumulator']
-         basis = self.global_state['basis']
-
-         history.append(g)
-         if len(history) < k:
-             basis_t = torch.randn(g.numel(), k, device=g.device, dtype=g.dtype)
-             history_basis = torch.stack(tuple(history), -1)
-             basis_t[:, -len(history):] = history_basis
-
-         else:
-             basis_t = torch.stack(tuple(history), -1)
-
-         basis_t[:,:-1] = basis_t[:, :-1] - basis_t[:, 1:]
-         basis_t = (basis_t - basis_t.mean()) / basis_t.std()
-
-         basis.lerp_(basis_t, 1-basis_beta)
-         update_subspace_preconditioner_(g, basis, accumulator, beta)
-
-         if 'inner' in self.children:
-             tensors = apply_transform(self.children['inner'], tensors, params, grads)
-             g = torch.cat([t.view(-1) for t in tensors])
-
-         try:
-             preconditioned = apply_subspace_preconditioner(g, basis, accumulator)
-         except torch.linalg.LinAlgError:
-             preconditioned = g.clip(-0.1,0.1)
-         vec_to_tensors_(preconditioned, tensors)
-
-         return tensors
-
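The hunk above removes torchzero/modules/experimental/subspace_preconditioners.py (entry 151). Both deleted preconditioners share one kernel: project the flattened gradient onto a d×k basis, accumulate outer products of the projection, and apply the accumulator's inverse square root before lifting back to full dimension. The sketch below restates that kernel with plain torch.linalg calls instead of torchzero's inv_sqrt_2x2 / matrix_power_eigh helpers; the eps clamp is an added safeguard, not part of the deleted code.

import torch

def subspace_whiten(g, basis, accumulator, beta=0.99, eps=1e-12):
    # g: (d,) flattened gradient; basis: (d, k); accumulator: (k, k) EMA of
    # outer products of projected gradients. Mirrors the deleted
    # update_subspace_preconditioner_ / apply_subspace_preconditioner pair.
    p = basis.T @ g                                   # project to the subspace, (k,)
    accumulator.lerp_(torch.outer(p, p), 1 - beta)    # EMA update of subspace curvature
    evals, evecs = torch.linalg.eigh(accumulator)
    inv_sqrt = evecs @ torch.diag(evals.clamp(min=eps).rsqrt()) @ evecs.T
    return basis @ (inv_sqrt @ p)                     # whiten, then lift back to (d,)

With k much smaller than the parameter count, the eigendecomposition stays at k×k, which is the point of working in a subspace.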
torchzero/modules/experimental/tensor_adagrad.py (deleted)
@@ -1,42 +0,0 @@
- from collections import deque
-
- import torch
-
- from ...core import Chainable, TensorwiseTransform
- from ...utils.linalg import matrix_power_eigh
-
-
- class TensorAdagrad(TensorwiseTransform):
-     """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
-
-     .. warning::
-         Experimental.
-     """
-     def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
-         defaults = dict(history_size=history_size, reg=reg)
-         super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
-
-     @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, setting):
-         reg = setting['reg']
-         if 'history' not in state:
-             state['history'] = deque(maxlen=setting['history_size'])
-
-         g = tensor.view(-1)
-         history = state['history']
-         history.append(g.clone())
-
-         I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
-         g_k = history[0]
-         outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-         if len(history) > 1:
-             for g_k in list(history)[1:]:
-                 outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
-
-         state['outer'] = outer.add_(I)
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, setting):
-         outer = state['outer']
-         P = matrix_power_eigh(outer, -1/2)
-         return (P @ tensor.ravel()).view_as(tensor)
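The hunk above removes torchzero/modules/experimental/tensor_adagrad.py (entry 152). The deleted TensorAdagrad builds a full d×d preconditioner from a history of flattened gradients, weighting each outer product by a clipped ⟨g_k, g⟩ term, and applies its inverse matrix square root. The sketch below shows the plain full-matrix Adagrad variant of that idea, without the extra weighting and with a torch eigendecomposition standing in for matrix_power_eigh; it is an illustration, not the removed module.

import torch

def full_matrix_adagrad_precondition(history, g, reg=1e-8):
    # history: list of past flattened gradients, each of shape (d,); g: current gradient.
    # Accumulate G = sum_k g_k g_k^T + reg*I and return G^{-1/2} @ g.
    G = reg * torch.eye(g.numel(), device=g.device, dtype=g.dtype)
    for gk in history:
        G += torch.outer(gk, gk)
    evals, evecs = torch.linalg.eigh(G)
    P = evecs @ torch.diag(evals.clamp(min=reg).rsqrt()) @ evecs.T
    return P @ g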