torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py (new file)
@@ -0,0 +1,265 @@
+ from collections import deque
+ from operator import itemgetter
+ from typing import Any
+
+ import torch
+
+ from ....core import Chainable, Module, Transform, Vars, apply, maybe_chain
+ from ....utils import NumberList, TensorList, as_tensorlist
+
+
+ def _adaptive_damping(
+     s_k: TensorList,
+     y_k: TensorList,
+     ys_k: torch.Tensor,
+     init_damping = 0.99,
+     eigval_bounds = (0.01, 1.5)
+ ):
+     # adaptive damping from Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
+     sigma_l, sigma_h = eigval_bounds
+     u = ys_k / s_k.dot(s_k)
+     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
+     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
+     else: tau = init_damping
+     y_k = tau * y_k + (1-tau) * s_k
+     ys_k = s_k.dot(y_k)
+
+     return s_k, y_k, ys_k
+
+ def lbfgs(
+     tensors_: TensorList,
+     vars: Vars,
+     s_history: deque[TensorList],
+     y_history: deque[TensorList],
+     sy_history: deque[torch.Tensor],
+     y_k: TensorList | None,
+     ys_k: torch.Tensor | None,
+     z_tfm: Any,
+ ):
+     if len(s_history) == 0 or y_k is None or ys_k is None:
+         # dir = params.grad.sign() # may work fine
+
+         # initial step size guess taken from pytorch L-BFGS
+         return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
+
+     else:
+         # 1st loop
+         alpha_list = []
+         q = tensors_.clone()
+         for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+             p_i = 1 / ys_i # this is also denoted as ρ (rho)
+             alpha = p_i * s_i.dot(q)
+             alpha_list.append(alpha)
+             q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+         # calculate z
+         # s.y/y.y is the scalar usually denoted γ (gamma)
+         # H0 = (s.y/y.y) * I, and z = H0 @ q
+         z = q * (ys_k / (y_k.dot(y_k)))
+
+         if z_tfm is not None:
+             z = TensorList(apply(z_tfm, tensors=z, params=vars.params, grads=vars.grad, vars=vars))
+
+         # 2nd loop
+         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+             p_i = 1 / ys_i
+             beta_i = p_i * y_i.dot(z)
+             z.add_(s_i, alpha = alpha_i - beta_i)
+
+         return z
+
+ def _apply_tfms_into_history(
+     self: Module,
+     params: list[torch.Tensor],
+     vars: Vars,
+     update: list[torch.Tensor],
+ ):
+     if 'params_history_tfm' in self.children:
+         params = apply(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+
+     if 'grad_history_tfm' in self.children:
+         update = apply(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=vars.grad, vars=vars)
+
+     return params, update
+
+ def _apply_tfms_into_precond(
+     self: Module,
+     params: list[torch.Tensor],
+     vars: Vars,
+     update: list[torch.Tensor],
+ ):
+     if 'params_precond_tfm' in self.children:
+         params = apply(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+
+     if 'grad_precond_tfm' in self.children:
+         update = apply(self.children['grad_precond_tfm'], tensors=update, params=params, grads=vars.grad, vars=vars)
+
+     return params, update
+
+
+ class ModularLBFGS(Module):
+     """L-BFGS with the ability to apply transforms to many inner variables.
+
+     Args:
+         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
+         tol (float | None, optional):
+             tolerance on the minimal gradient difference, to avoid instability after converging to a minimum. Defaults to 1e-10.
+         damping (bool, optional):
+             whether to use adaptive damping. The learning rate might need to be lowered with this enabled. Defaults to False.
+         init_damping (float, optional):
+             initial damping for adaptive damping. Defaults to 0.9.
+         eigval_bounds (tuple, optional):
+             eigenvalue bounds for adaptive damping. Defaults to (0.5, 50).
+         update_freq (int, optional):
+             how often to update the L-BFGS history. Defaults to 1.
+         z_tfm (Chainable | None, optional):
+             transform module applied to the initial H^-1 @ q guess. Defaults to None.
+         params_history_tfm (Chainable | None, optional):
+             transform module applied to params before adding s_k to the history. Defaults to None.
+         grad_history_tfm (Chainable | None, optional):
+             transform module applied to grads before adding y_k to the history. Defaults to None.
+         params_precond_tfm (Chainable | None, optional):
+             transform module applied to params to calculate s_k before preconditioning. Defaults to None.
+         grad_precond_tfm (Chainable | None, optional):
+             transform module applied to grads to calculate y_k before preconditioning. Defaults to None.
+         update_precond_tfm (Chainable | None, optional):
+             transform module applied to the grads that are being preconditioned. Defaults to None.
+     """
+     def __init__(
+         self,
+         history_size=10,
+         tol: float | None = 1e-10,
+         damping: bool = False,
+         init_damping=0.9,
+         eigval_bounds=(0.5, 50),
+         update_freq = 1,
+         params_history_tfm: Chainable | None = None,
+         grad_history_tfm: Chainable | None = None,
+         params_precond_tfm: Chainable | None = None,
+         grad_precond_tfm: Chainable | None = None,
+         update_precond_tfm: Chainable | None = None,
+         z_tfm: Chainable | None = None,
+     ):
+         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, update_freq=update_freq)
+         super().__init__(defaults)
+
+         self.global_state['s_history'] = deque(maxlen=history_size)
+         self.global_state['y_history'] = deque(maxlen=history_size)
+         self.global_state['sy_history'] = deque(maxlen=history_size)
+
+         loc = locals().copy()
+         for k in ('update_precond_tfm', 'params_history_tfm', 'grad_history_tfm', 'params_precond_tfm', 'grad_precond_tfm', 'z_tfm'):
+             v = loc[k]
+             if v is not None:
+                 self.set_child(k, v)
+
+     def reset(self):
+         """Resets the internal state of the module."""
+         # super().reset() # Clears self.state (per-parameter) if any, and "step"
+         self.state.clear()
+         self.global_state['step'] = 0
+         self.global_state['s_history'].clear()
+         self.global_state['y_history'].clear()
+         self.global_state['sy_history'].clear()
+
+     @torch.no_grad
+     def step(self, vars):
+         params = as_tensorlist(vars.params)
+         update = as_tensorlist(vars.get_update())
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         # history of s and y
+         s_history: deque[TensorList] = self.global_state['s_history']
+         y_history: deque[TensorList] = self.global_state['y_history']
+         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+
+         tol, damping, init_damping, eigval_bounds, update_freq = itemgetter(
+             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq')(self.settings[params[0]])
+
+         # params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params, cls=NumberList)
+         # l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+
+         # params and update that go into the history
+         params_h, update_h = _apply_tfms_into_history(
+             self,
+             params=params,
+             vars=vars,
+             update=update,
+         )
+
+         prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h', params=params, cls=TensorList)
+
+         # 1st step - there are no previous params and grads, `lbfgs` will do a normalized SGD step
+         if step == 0:
+             s_k_h = None; y_k_h = None; ys_k_h = None
+         else:
+             s_k_h = params_h - prev_params_h
+             y_k_h = update_h - prev_grad_h
+             ys_k_h = s_k_h.dot(y_k_h)
+
+             if damping:
+                 s_k_h, y_k_h, ys_k_h = _adaptive_damping(s_k_h, y_k_h, ys_k_h, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+         prev_params_h.copy_(params_h)
+         prev_grad_h.copy_(update_h)
+
+         # update effective preconditioning state
+         if step % update_freq == 0:
+             if ys_k_h is not None and ys_k_h > 1e-10:
+                 assert s_k_h is not None and y_k_h is not None
+                 s_history.append(s_k_h)
+                 y_history.append(y_k_h)
+                 sy_history.append(ys_k_h)
+
+         # step with the inner module before applying the preconditioner
+         if 'update_precond_tfm' in self.children:
+             update_precond_tfm = self.children['update_precond_tfm']
+             inner_vars = update_precond_tfm.step(vars.clone(clone_update=True))
+             vars.update_attrs_from_clone_(inner_vars)
+             tensors = inner_vars.update
+             assert tensors is not None
+         else:
+             tensors = update.clone()
+
+         # transforms into the preconditioner
+         params_p, update_p = _apply_tfms_into_precond(self, params=params, vars=vars, update=update)
+         prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p', params=params, cls=TensorList)
+
+         if step == 0:
+             s_k_p = None; y_k_p = None; ys_k_p = None
+
+         else:
+             s_k_p = params_p - prev_params_p
+             y_k_p = update_p - prev_grad_p
+             ys_k_p = s_k_p.dot(y_k_p)
+
+             if damping:
+                 s_k_p, y_k_p, ys_k_p = _adaptive_damping(s_k_p, y_k_p, ys_k_p, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+         prev_params_p.copy_(params_p)
+         prev_grad_p.copy_(update_p)
+
+         # tolerance on the gradient difference to avoid exploding after converging
+         if tol is not None:
+             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
+                 vars.update = update # may have been updated by the inner module, so use it here
+                 return vars
+
+         # precondition
+         dir = lbfgs(
+             tensors_=as_tensorlist(tensors),
+             vars=vars,
+             s_history=s_history,
+             y_history=y_history,
+             sy_history=sy_history,
+             y_k=y_k_p,
+             ys_k=ys_k_p,
+             z_tfm=self.children.get('z_tfm', None),
+         )
+
+         vars.update = dir
+
+         return vars
+
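For reference, the `lbfgs` function above (and its twin in `lbfgs.py` below) is the standard L-BFGS two-loop recursion. Written out, with ρ_i = 1 / (s_i·y_i) and the scalar γ_k referred to in the comments used as the initial inverse-Hessian guess, it computes:

```latex
% L-BFGS two-loop recursion, as implemented by lbfgs() above
\begin{aligned}
& q \gets g_k; \qquad
  \text{for } i = k-1,\dots,k-m:\quad
  \alpha_i = \rho_i\, s_i^{\top} q, \qquad q \gets q - \alpha_i\, y_i \\
& z \gets \gamma_k\, q, \qquad
  \gamma_k = \frac{s_k^{\top} y_k}{y_k^{\top} y_k}
  \quad\text{(i.e. } H_k^{0} = \gamma_k I\text{)} \\
& \text{for } i = k-m,\dots,k-1:\quad
  \beta_i = \rho_i\, y_i^{\top} z, \qquad z \gets z + (\alpha_i - \beta_i)\, s_i
\end{aligned}
```

The returned z approximates H_k g_k, the quasi-Newton direction; the surrounding module only decides what goes in as g_k (the possibly transformed update) and what is done with z afterwards.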
torchzero/modules/quasi_newton/lbfgs.py (new file)
@@ -0,0 +1,228 @@
+ from collections import deque
+ from operator import itemgetter
+ import torch
+
+ from ...core import Transform, Chainable, Module, Vars, apply
+ from ...utils import TensorList, as_tensorlist, NumberList
+
+
+ def _adaptive_damping(
+     s_k: TensorList,
+     y_k: TensorList,
+     ys_k: torch.Tensor,
+     init_damping = 0.99,
+     eigval_bounds = (0.01, 1.5)
+ ):
+     # adaptive damping from Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
+     sigma_l, sigma_h = eigval_bounds
+     u = ys_k / s_k.dot(s_k)
+     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
+     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
+     else: tau = init_damping
+     y_k = tau * y_k + (1-tau) * s_k
+     ys_k = s_k.dot(y_k)
+
+     return s_k, y_k, ys_k
+
+ def lbfgs(
+     tensors_: TensorList,
+     s_history: deque[TensorList],
+     y_history: deque[TensorList],
+     sy_history: deque[torch.Tensor],
+     y_k: TensorList | None,
+     ys_k: torch.Tensor | None,
+     z_beta: float | None,
+     z_ema: TensorList | None,
+     step: int,
+ ):
+     if len(s_history) == 0 or y_k is None or ys_k is None:
+         # dir = params.grad.sign() # may work fine
+
+         # initial step size guess taken from pytorch L-BFGS
+         return tensors_.mul_(min(1.0, 1.0 / tensors_.abs().global_sum())) # pyright: ignore[reportArgumentType]
+
+     else:
+         # 1st loop
+         alpha_list = []
+         q = tensors_.clone()
+         for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+             p_i = 1 / ys_i # this is also denoted as ρ (rho)
+             alpha = p_i * s_i.dot(q)
+             alpha_list.append(alpha)
+             q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+         # calculate z
+         # s.y/y.y is the scalar usually denoted γ (gamma)
+         # H0 = (s.y/y.y) * I, and z = H0 @ q
+         z = q * (ys_k / (y_k.dot(y_k)))
+
+         # an attempt at adding momentum; lerping the initial z seems stable compared to other variables
+         if z_beta is not None:
+             assert z_ema is not None
+             if step == 0: z_ema.copy_(z)
+             else: z_ema.lerp_(z, 1-z_beta) # in-place, so the EMA persists across steps
+             z = z_ema.clone() # clone so the 2nd loop doesn't mutate the stored EMA
+
+         # 2nd loop
+         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+             p_i = 1 / ys_i
+             beta_i = p_i * y_i.dot(z)
+             z.add_(s_i, alpha = alpha_i - beta_i)
+
+         return z
+
+ def _lerp_params_update_(
+     self_: Module,
+     params: list[torch.Tensor],
+     update: list[torch.Tensor],
+     params_beta: list[float | None],
+     grads_beta: list[float | None],
+ ):
+     for i, (p, u, p_beta, u_beta) in enumerate(zip(params.copy(), update.copy(), params_beta, grads_beta)):
+         if p_beta is not None or u_beta is not None:
+             state = self_.state[p]
+
+             if p_beta is not None:
+                 if 'param_ema' not in state: state['param_ema'] = p.clone()
+                 else: state['param_ema'].lerp_(p, 1-p_beta)
+                 params[i] = state['param_ema']
+
+             if u_beta is not None:
+                 if 'grad_ema' not in state: state['grad_ema'] = u.clone()
+                 else: state['grad_ema'].lerp_(u, 1-u_beta)
+                 update[i] = state['grad_ema']
+
+     return TensorList(params), TensorList(update)
+
+ class LBFGS(Module):
+     """L-BFGS.
+
+     Args:
+         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
+         tol (float | None, optional):
+             tolerance on the minimal gradient difference, to avoid instability after converging to a minimum. Defaults to 1e-10.
+         damping (bool, optional):
+             whether to use adaptive damping. The learning rate might need to be lowered with this enabled. Defaults to False.
+         init_damping (float, optional):
+             initial damping for adaptive damping. Defaults to 0.9.
+         eigval_bounds (tuple, optional):
+             eigenvalue bounds for adaptive damping. Defaults to (0.5, 50).
+         params_beta (float | None, optional):
+             if not None, an EMA of parameters is used for the preconditioner update. Defaults to None.
+         grads_beta (float | None, optional):
+             if not None, an EMA of gradients is used for the preconditioner update. Defaults to None.
+         update_freq (int, optional):
+             how often to update the L-BFGS history. Defaults to 1.
+         z_beta (float | None, optional):
+             optional EMA for the initial H^-1 @ q. Acts as a kind of momentum but is prone to getting stuck. Defaults to None.
+         tol_reset (bool, optional):
+             if True, whenever the gradient difference is less than `tol`, the history will be reset. Defaults to False.
+         inner (Chainable | None, optional):
+             optional inner modules applied after updating the L-BFGS history and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         history_size=10,
+         tol: float | None = 1e-10,
+         damping: bool = False,
+         init_damping=0.9,
+         eigval_bounds=(0.5, 50),
+         params_beta: float | None = None,
+         grads_beta: float | None = None,
+         update_freq = 1,
+         z_beta: float | None = None,
+         tol_reset: bool = False,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
+         super().__init__(defaults)
+
+         self.global_state['s_history'] = deque(maxlen=history_size)
+         self.global_state['y_history'] = deque(maxlen=history_size)
+         self.global_state['sy_history'] = deque(maxlen=history_size)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     def reset(self):
+         self.state.clear()
+         self.global_state['step'] = 0
+         self.global_state['s_history'].clear()
+         self.global_state['y_history'].clear()
+         self.global_state['sy_history'].clear()
+
+     @torch.no_grad
+     def step(self, vars):
+         params = as_tensorlist(vars.params)
+         update = as_tensorlist(vars.get_update())
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         # history of s and y
+         s_history: deque[TensorList] = self.global_state['s_history']
+         y_history: deque[TensorList] = self.global_state['y_history']
+         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+
+         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
+             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
+         params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params)
+
+         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+         prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+
+         # 1st step - there are no previous params and grads, `lbfgs` will do a normalized SGD step
+         if step == 0:
+             s_k = None; y_k = None; ys_k = None
+         else:
+             s_k = l_params - prev_l_params
+             y_k = l_update - prev_l_grad
+             ys_k = s_k.dot(y_k)
+
+             if damping:
+                 s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+         prev_l_params.copy_(l_params)
+         prev_l_grad.copy_(l_update)
+
+         # update effective preconditioning state
+         if step % update_freq == 0:
+             if ys_k is not None and ys_k > 1e-10:
+                 assert s_k is not None and y_k is not None
+                 s_history.append(s_k)
+                 y_history.append(y_k)
+                 sy_history.append(ys_k)
+
+         # step with the inner module before applying the preconditioner
+         if self.children:
+             update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+
+         # tolerance on the gradient difference to avoid exploding after converging
+         if tol is not None:
+             if y_k is not None and y_k.abs().global_max() <= tol:
+                 vars.update = update # may have been updated by the inner module, so use it here
+                 if tol_reset: self.reset()
+                 return vars
+
+         # lerp the initial H^-1 @ q guess
+         z_ema = None
+         if z_beta is not None:
+             z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+
+         # precondition
+         dir = lbfgs(
+             tensors_=as_tensorlist(update),
+             s_history=s_history,
+             y_history=y_history,
+             sy_history=sy_history,
+             y_k=y_k,
+             ys_k=ys_k,
+             z_beta = z_beta,
+             z_ema = z_ema,
+             step=step
+         )
+
+         vars.update = dir
+
+         return vars
+
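As a sanity check on the arithmetic, here is a minimal sketch (not part of torchzero; the function name `two_loop_direction` is hypothetical) of the same two-loop recursion on flat 1-D PyTorch tensors. The `dot`, `sub_` and `add_` TensorList calls in `lbfgs()` above map one-to-one onto these tensor operations:

```python
# Minimal sketch, assuming flat 1-D tensors instead of per-parameter TensorLists.
from collections import deque
import torch

def two_loop_direction(g: torch.Tensor,
                       s_history: deque[torch.Tensor],
                       y_history: deque[torch.Tensor]) -> torch.Tensor:
    if not s_history:
        # same first-step heuristic as the module: a scaled gradient step
        return g * min(1.0, 1.0 / g.abs().sum().item())

    q = g.clone()
    alphas = []
    # 1st loop: newest pair first
    for s_i, y_i in zip(reversed(s_history), reversed(y_history)):
        rho_i = 1.0 / y_i.dot(s_i)
        alpha_i = rho_i * s_i.dot(q)
        alphas.append(alpha_i)
        q -= alpha_i * y_i

    # initial inverse-Hessian guess H0 = gamma * I
    s_k, y_k = s_history[-1], y_history[-1]
    gamma = s_k.dot(y_k) / y_k.dot(y_k)
    z = gamma * q

    # 2nd loop: oldest pair first, alphas consumed in reverse
    for s_i, y_i, alpha_i in zip(s_history, y_history, reversed(alphas)):
        rho_i = 1.0 / y_i.dot(s_i)
        beta_i = rho_i * y_i.dot(z)
        z += (alpha_i - beta_i) * s_i

    return z  # approximates H_k @ g
```

On a quadratic with fixed Hessian A, feeding in a few pairs (s, y) = (Δx, AΔx) should make the returned direction approach A⁻¹g, which is a quick way to test changes to the module.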
torchzero/modules/quasi_newton/lsr1.py (new file)
@@ -0,0 +1,170 @@
+ from collections import deque
+ from operator import itemgetter
+
+ import torch
+
+ from ...core import Chainable, Module, Transform, Vars, apply
+ from ...utils import NumberList, TensorList, as_tensorlist
+
+ from .lbfgs import _lerp_params_update_
+
+ def lsr1_(
+     tensors_: TensorList,
+     s_history: deque[TensorList],
+     y_history: deque[TensorList],
+     step: int,
+     scale_second: bool,
+ ):
+     if step == 0 or not s_history:
+         # initial step size guess from pytorch
+         tensors_.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
+         return tensors_
+
+     m = len(s_history)
+
+     w_list: list[TensorList] = []
+     ww_list: list = [None for _ in range(m)]
+     wy_list: list = [None for _ in range(m)]
+
+     # 1st loop - all w_k = s_k - H_k_prev y_k
+     for k in range(m):
+         s_k = s_history[k]
+         y_k = y_history[k]
+
+         H_k = y_k.clone()
+         for j in range(k):
+             w_j = w_list[j]
+             y_j = y_history[j]
+
+             wy = wy_list[j]
+             if wy is None: wy = wy_list[j] = w_j.dot(y_j)
+
+             ww = ww_list[j]
+             if ww is None: ww = ww_list[j] = w_j.dot(w_j)
+
+             if wy == 0: continue
+
+             H_k.add_(w_j, alpha=w_j.dot(y_k) / wy) # pyright:ignore[reportArgumentType]
+
+         w_k = s_k - H_k
+         w_list.append(w_k)
+
+     Hx = tensors_.clone()
+     for k in range(m):
+         w_k = w_list[k]
+         y_k = y_history[k]
+         wy = wy_list[k]
+         ww = ww_list[k]
+
+         if wy is None: wy = w_k.dot(y_k) # this happens when m = 1, so the inner loop doesn't run
+         if ww is None: ww = w_k.dot(w_k)
+
+         if wy == 0: continue
+
+         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
+
+     if scale_second and step == 1:
+         Hx.div_(max(1.0, tensors_.abs().global_sum())) # pyright:ignore[reportArgumentType]
+     return Hx
+
+
+ class LSR1(Module):
+     """Limited-memory SR1 (L-SR1).
+
+     Args:
+         history_size (int, optional): Number of past parameter differences (s)
+             and gradient differences (y) to store. Defaults to 10.
+         tol (float, optional): Tolerance on the gradient difference y_k. If the max
+             abs value of y_k is below this, the preconditioning step is skipped,
+             assuming convergence. Defaults to 1e-8.
+         params_beta (float | None, optional): If not None, an EMA of parameters is used for
+             the preconditioner update (s_k vector). Defaults to None.
+         grads_beta (float | None, optional): If not None, an EMA of gradients is used for
+             the preconditioner update (y_k vector). Defaults to None.
+         update_freq (int, optional): How often to update the L-SR1 history. Defaults to 1.
+         scale_second (bool, optional): Whether to also apply the initial step-size scaling
+             to the second step. Defaults to True.
+         inner (Chainable | None, optional): Optional inner modules applied after updating
+             the L-SR1 history and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         history_size: int = 10,
+         tol: float = 1e-8,
+         params_beta: float | None = None,
+         grads_beta: float | None = None,
+         update_freq: int = 1,
+         scale_second: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(
+             history_size=history_size, tol=tol,
+             params_beta=params_beta, grads_beta=grads_beta,
+             update_freq=update_freq, scale_second=scale_second
+         )
+         super().__init__(defaults)
+
+         self.global_state['s_history'] = deque(maxlen=history_size)
+         self.global_state['y_history'] = deque(maxlen=history_size)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     def reset(self):
+         self.state.clear()
+         self.global_state['step'] = 0
+         self.global_state['s_history'].clear()
+         self.global_state['y_history'].clear()
+
+
+     @torch.no_grad
+     def step(self, vars: Vars):
+         params = as_tensorlist(vars.params)
+         update = as_tensorlist(vars.get_update())
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         s_history: deque[TensorList] = self.global_state['s_history']
+         y_history: deque[TensorList] = self.global_state['y_history']
+
+         settings = self.settings[params[0]]
+         tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)
+
+         params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta', params=params) # type: ignore
+         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)
+
+         prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+
+         y_k = None
+         if step != 0:
+             if step % update_freq == 0:
+                 s_k = l_params - prev_l_params
+                 y_k = l_update - prev_l_grad
+
+                 s_history.append(s_k)
+                 y_history.append(y_k)
+
+         prev_l_params.copy_(l_params)
+         prev_l_grad.copy_(l_update)
+
+         if 'inner' in self.children:
+             update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+
+         # tolerance on the gradient difference to avoid exploding after converging
+         if tol is not None:
+             if y_k is not None and y_k.abs().global_max() <= tol:
+                 vars.update = update
+                 return vars
+
+         dir = lsr1_(
+             tensors_=update,
+             s_history=s_history,
+             y_history=y_history,
+             step=step,
+             scale_second=scale_second,
+         )
+
+         vars.update = dir
+
+         return vars
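For reference, `lsr1_` above applies the limited-memory SR1 inverse-Hessian update matrix-free, with H_0 = I. Each stored pair contributes a rank-1 correction, and because the corrections are built once and then applied to the incoming update x, the result unrolls into a single sum:

```latex
% SR1 recursion applied matrix-free by lsr1_() above, with H_0 = I
\begin{aligned}
w_k &= s_k - H_{k-1} y_k, \qquad
H_k = H_{k-1} + \frac{w_k w_k^{\top}}{w_k^{\top} y_k}, \\
H_m x &= x + \sum_{k=1}^{m} \frac{w_k^{\top} x}{w_k^{\top} y_k}\, w_k
\qquad \text{(the returned direction, with $x$ the incoming update)}
\end{aligned}
```

Unlike BFGS, the SR1 update does not guarantee positive-definiteness and w_k^T y_k can vanish, so the division has to be guarded; here that is the `if wy == 0: continue` check, and a common stricter variant instead skips pairs with |w_k^T y_k| < R ||w_k|| ||y_k|| for a small R.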