torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/quasi_newton/olbfgs.py (new file)
@@ -0,0 +1,196 @@
+ from collections import deque
+ from functools import partial
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Transform, Vars, apply
+ from ...utils import NumberList, TensorList, as_tensorlist
+ from .lbfgs import _adaptive_damping, lbfgs
+
+
+ @torch.no_grad
+ def _store_sk_yk_after_step_hook(optimizer, vars: Vars, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
+     assert vars.closure is not None
+     with torch.enable_grad(): vars.closure()
+     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in vars.params]
+     s_k = vars.params - prev_params
+     y_k = grad - prev_grad
+     ys_k = s_k.dot(y_k)
+
+     if damping:
+         s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+     if ys_k > 1e-10:
+         s_history.append(s_k)
+         y_history.append(y_k)
+         sy_history.append(ys_k)
+
+
+
+ class OnlineLBFGS(Module):
+     """Online L-BFGS.
+     Parameter and gradient differences are sampled from the same mini-batch by performing an extra forward and backward pass.
+     However I did a bunch of experiments and the online part doesn't seem to help. Normal L-BFGS is usually still
+     better because it performs twice as many steps, and it is reasonably stable with normalization or grafting.
+
+     Args:
+         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
+         sample_grads (str, optional):
+             - "before" - samples current mini-batch gradient at previous and current parameters, calculates y_k
+               and adds it to history before stepping.
+             - "after" - samples current mini-batch gradient at parameters before stepping and after updating parameters.
+               s_k and y_k are added after parameter update, therefore they are delayed by 1 step.
+
+             In practice both modes behave very similarly. Defaults to 'before'.
+         tol (float | None, optional):
+             tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
+         damping (bool, optional):
+             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
+         init_damping (float, optional):
+             initial damping for adaptive dampening. Defaults to 0.9.
+         eigval_bounds (tuple, optional):
+             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
+         params_beta (float | None, optional):
+             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
+         grads_beta (float | None, optional):
+             if not None, EMA of gradients is used for preconditioner update. Defaults to None.
+         update_freq (int, optional):
+             how often to update L-BFGS history. Defaults to 1.
+         z_beta (float | None, optional):
+             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
+         inner (Chainable | None, optional):
+             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         history_size=10,
+         sample_grads: Literal['before', 'after'] = 'before',
+         tol: float | None = 1e-10,
+         damping: bool = False,
+         init_damping=0.9,
+         eigval_bounds=(0.5, 50),
+         z_beta: float | None = None,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta)
+         super().__init__(defaults)
+
+         self.global_state['s_history'] = deque(maxlen=history_size)
+         self.global_state['y_history'] = deque(maxlen=history_size)
+         self.global_state['sy_history'] = deque(maxlen=history_size)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     def reset(self):
+         """Resets the internal state of the L-SR1 module."""
+         # super().reset() # Clears self.state (per-parameter) if any, and "step"
+         # Re-initialize L-SR1 specific global state
+         self.state.clear()
+         self.global_state['step'] = 0
+         self.global_state['s_history'].clear()
+         self.global_state['y_history'].clear()
+         self.global_state['sy_history'].clear()
+
+     @torch.no_grad
+     def step(self, vars):
+         assert vars.closure is not None
+
+         params = as_tensorlist(vars.params)
+         update = as_tensorlist(vars.get_update())
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         # history of s and k
+         s_history: deque[TensorList] = self.global_state['s_history']
+         y_history: deque[TensorList] = self.global_state['y_history']
+         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
+
+         tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
+             'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])
+
+         # sample gradient at previous params with current mini-batch
+         if sample_grads == 'before':
+             prev_params = self.get_state('prev_params', params=params, cls=TensorList)
+             if step == 0:
+                 s_k = None; y_k = None; ys_k = None
+             else:
+                 s_k = params - prev_params
+
+                 current_params = params.clone()
+                 params.set_(prev_params)
+                 with torch.enable_grad(): vars.closure()
+                 y_k = update - params.grad
+                 ys_k = s_k.dot(y_k)
+                 params.set_(current_params)
+
+                 if damping:
+                     s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+
+                 if ys_k > 1e-10:
+                     s_history.append(s_k)
+                     y_history.append(y_k)
+                     sy_history.append(ys_k)
+
+             prev_params.copy_(params)
+
+         # use previous s_k, y_k pair, samples gradient at current batch before and after updating parameters
+         elif sample_grads == 'after':
+             if len(s_history) == 0:
+                 s_k = None; y_k = None; ys_k = None
+             else:
+                 s_k = s_history[-1]
+                 y_k = y_history[-1]
+                 ys_k = s_k.dot(y_k)
+
+             # this will run after params are updated by Modular after running all future modules
+             vars.post_step_hooks.append(
+                 partial(
+                     _store_sk_yk_after_step_hook,
+                     prev_params=params.clone(),
+                     prev_grad=update.clone(),
+                     damping=damping,
+                     init_damping=init_damping,
+                     eigval_bounds=eigval_bounds,
+                     s_history=s_history,
+                     y_history=y_history,
+                     sy_history=sy_history,
+                 ))
+
+         else:
+             raise ValueError(sample_grads)
+
+         # step with inner module before applying preconditioner
+         if self.children:
+             update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+
+         # tolerance on gradient difference to avoid exploding after converging
+         if tol is not None:
+             if y_k is not None and y_k.abs().global_max() <= tol:
+                 vars.update = update # may have been updated by inner module, probably makes sense to use it here?
+                 return vars
+
+         # lerp initial H^-1 @ q guess
+         z_ema = None
+         if z_beta is not None:
+             z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+
+         # precondition
+         dir = lbfgs(
+             tensors_=as_tensorlist(update),
+             s_history=s_history,
+             y_history=y_history,
+             sy_history=sy_history,
+             y_k=y_k,
+             ys_k=ys_k,
+             z_beta = z_beta,
+             z_ema = z_ema,
+             step=step
+         )
+
+         vars.update = dir
+
+         return vars
+
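The preconditioning itself is delegated to the `lbfgs` helper imported from `.lbfgs`, which this diff does not include. For orientation, below is a minimal sketch of the standard two-loop recursion such a helper typically implements; it operates on flat tensors rather than torchzero's `TensorList` and ignores the `z_beta`/`z_ema` smoothing, so it is an assumption about the helper's behaviour, not a copy of it.

import torch

def lbfgs_two_loop(g, s_history, y_history, sy_history):
    # Approximate H^{-1} @ g from stored (s, y) pairs with the two-loop recursion.
    q = g.clone()
    alphas = []
    # first loop: newest pair to oldest
    for s, y, sy in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
        alpha = torch.dot(s, q) / sy
        q -= alpha * y
        alphas.append(alpha)
    # initial Hessian guess H0 = gamma * I with gamma = s^T y / y^T y
    if s_history:
        gamma = sy_history[-1] / torch.dot(y_history[-1], y_history[-1])
        z = gamma * q
    else:
        z = q
    # second loop: oldest pair to newest
    for (s, y, sy), alpha in zip(zip(s_history, y_history, sy_history), reversed(alphas)):
        beta = torch.dot(y, z) / sy
        z += s * (alpha - beta)
    return z

# smoke test on a quadratic f(x) = 0.5 * x^T A x, mirroring the ys_k > 1e-10 guard above
A = torch.diag(torch.tensor([1.0, 10.0]))
x = torch.tensor([1.0, 1.0])
s_hist, y_hist, sy_hist = [], [], []
for _ in range(10):
    g = A @ x
    x_new = x - 0.5 * lbfgs_two_loop(g, s_hist, y_hist, sy_hist)
    s, y = x_new - x, A @ x_new - g
    if torch.dot(s, y) > 1e-10:
        s_hist.append(s); y_hist.append(y); sy_hist.append(torch.dot(s, y))
    x = x_new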
torchzero/modules/quasi_newton/quasi_newton.py (new file)
@@ -0,0 +1,475 @@
+ """Use BFGS or maybe SR1."""
+ from typing import Any, Literal
+ from abc import ABC, abstractmethod
+ from collections.abc import Mapping
+ import torch
+
+ from ...core import Chainable, Module, Preconditioner, TensorwisePreconditioner
+ from ...utils import TensorList, set_storage_
+
+ def _safe_dict_update_(d1_:dict, d2:dict):
+     inter = set(d1_.keys()).intersection(d2.keys())
+     if len(inter) > 0: raise RuntimeError(f"Duplicate keys {inter}")
+     d1_.update(d2)
+
+ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
+     if (beta is None) or (beta == 0) or (key not in state): state[key] = value
+     elif state[key].shape != value.shape: state[key] = value
+     else: state[key].lerp_(value, 1-beta)
+
+ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
+     def __init__(
+         self,
+         defaults: dict | None = None,
+         init_scale: float | Literal["auto"] = "auto",
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inverse: bool = True,
+         inner: Chainable | None = None,
+     ):
+         if defaults is None: defaults = {}
+         _safe_dict_update_(defaults, dict(init_scale=init_scale, tol=tol, tol_reset=tol_reset, scale_second=scale_second, inverse=inverse, beta=beta, reset_interval=reset_interval))
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, scale_first=scale_first, inner=inner)
+
+     def _get_init_scale(self,s:torch.Tensor,y:torch.Tensor) -> torch.Tensor | float:
+         """returns multiplier to H or B"""
+         ys = y.dot(s)
+         yy = y.dot(y)
+         if ys != 0 and yy != 0: return yy/ys
+         return 1
+
+     def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
+         set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
+         if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
+         if init_scale >= 1:
+             if inverse: M /= init_scale
+             else: M *= init_scale
+
+     def update_H(self, H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
+                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+         """update hessian inverse"""
+         raise NotImplementedError
+
+     def update_B(self, B:torch.Tensor, s:torch.Tensor, y:torch.Tensor, p:torch.Tensor, g:torch.Tensor,
+                  p_prev:torch.Tensor, g_prev:torch.Tensor, state: dict[str, Any], settings: Mapping[str, Any]) -> torch.Tensor:
+         """update hessian"""
+         raise NotImplementedError
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, state, settings):
+         p = param.view(-1); g = tensor.view(-1)
+         inverse = settings['inverse']
+         M_key = 'H' if inverse else 'B'
+         M = state.get(M_key, None)
+         step = state.get('step', 0)
+         init_scale = settings['init_scale']
+         tol = settings['tol']
+         tol_reset = settings['tol_reset']
+         reset_interval = settings['reset_interval']
+
+         if M is None:
+             M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
+             if isinstance(init_scale, (int, float)) and init_scale != 1:
+                 if inverse: M /= init_scale
+                 else: M *= init_scale
+
+             state[M_key] = M
+             state['p_prev'] = p.clone()
+             state['g_prev'] = g.clone()
+             return
+
+         p_prev = state['p_prev']
+         g_prev = state['g_prev']
+         s: torch.Tensor = p - p_prev
+         y: torch.Tensor = g - g_prev
+         state['p_prev'].copy_(p)
+         state['g_prev'].copy_(g)
+
+
+         if reset_interval is not None and step % reset_interval == 0:
+             self._reset_M_(M, s, y, inverse, init_scale)
+             return
+
+         # tolerance on gradient difference to avoid exploding after converging
+         if y.abs().max() <= tol:
+             # reset history
+             if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+             return
+
+         if step == 1 and init_scale == 'auto':
+             if inverse: M /= self._get_init_scale(s,y)
+             else: M *= self._get_init_scale(s,y)
+
+         beta = settings['beta']
+         if beta is not None and beta != 0: M = M.clone() # because all of them update it in-place
+
+         if inverse:
+             H_new = self.update_H(H=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+             _maybe_lerp_(state, 'H', H_new, beta)
+
+         else:
+             B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
+             _maybe_lerp_(state, 'B', B_new, beta)
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, state, settings):
+         step = state['step'] = state.get('step', 0) + 1
+
+         if settings['scale_second'] and step == 2:
+             s = max(1, tensor.abs().sum()) # pyright:ignore[reportArgumentType]
+             if s < settings['tol']: tensor = tensor/s
+
+         inverse = settings['inverse']
+         if inverse:
+             H = state['H']
+             return (H @ tensor.view(-1)).view_as(tensor)
+
+         B = state['B']
+
+         return torch.linalg.solve_ex(B, tensor.view(-1))[0].view_as(tensor) # pylint:disable=not-callable
+
+ # to avoid typing all arguments for each method
+ class QuasiNewtonH(HessianUpdateStrategy):
+     def __init__(
+         self,
+         init_scale: float | Literal["auto"] = "auto",
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         super().__init__(
+             defaults=None,
+             init_scale=init_scale,
+             tol=tol,
+             tol_reset=tol_reset,
+             reset_interval=reset_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             scale_second=scale_second,
+             concat_params=concat_params,
+             inverse=True,
+             inner=inner,
+         )
+ # ----------------------------------- BFGS ----------------------------------- #
+ def bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = torch.dot(s, y)
+     if sy <= tol: return H # don't reset H in this case
+     num1 = (sy + (y @ H @ y)) * s.outer(s)
+     term1 = num1.div_(sy**2)
+     num2 = (torch.outer(H @ y, s).add_(torch.outer(s, y) @ H))
+     term2 = num2.div_(sy)
+     H += term1.sub_(term2)
+     return H
+
+ class BFGS(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return bfgs_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ # ------------------------------------ SR1 ----------------------------------- #
+ def sr1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+     z = s - H@y
+     denom = torch.dot(z, y)
+
+     z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+     y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+     if y_norm*z_norm < tol: return H
+
+     # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+     if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+     H += torch.outer(z, z).div_(denom)
+     return H
+
+ class SR1(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ # BFGS has defaults - init_scale = "auto" and scale_second = False
+ # SR1 has defaults - init_scale = 1 and scale_second = True
+ # basically some methods work better with first and some with second.
+ # I inherit from BFGS or SR1 to avoid writing all those arguments again
+ # ------------------------------------ DFP ----------------------------------- #
+ def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = torch.dot(s, y)
+     if sy.abs() <= tol: return H
+     term1 = torch.outer(s, s).div_(sy)
+     denom = torch.dot(y, H @ y) #
+     if denom.abs() <= tol: return H
+     num = H @ torch.outer(y, y) @ H
+     term2 = num.div_(denom)
+     H += term1.sub_(term2)
+     return H
+
+ class DFP(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return dfp_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+ # formulas for methods below from Spedicato, E., & Huang, Z. (1997). Numerical experience with newton-like methods for nonlinear algebraic systems. Computing, 58(1), 69–89. doi:10.1007/bf02684472
+ # H' = H - (Hy - S)c^T / c^T*y
+ # the difference is how `c` is calculated
+
+ def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     c = H.T @ s
+     denom = c.dot(y)
+     if denom.abs() <= tol: return H
+     num = (H@y).sub_(s).outer(c)
+     H -= num/denom
+     return H
+
+ def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     c = y
+     denom = c.dot(y)
+     if denom.abs() <= tol: return H
+     num = (H@y).sub_(s).outer(c)
+     H -= num/denom
+     return H
+
+ def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
+     c = g_prev
+     denom = c.dot(y)
+     if denom.abs() <= tol: return H
+     num = (H@y).sub_(s).outer(c)
+     H -= num/denom
+     return H
+
+ def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     c = torch.linalg.multi_dot([H,H,y]) # pylint:disable=not-callable
+     denom = c.dot(y)
+     if denom.abs() <= tol: return H
+     num = (H@y).sub_(s).outer(c)
+     H -= num/denom
+     return H
+
+ class BroydenGood(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return broyden_good_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ class BroydenBad(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return broyden_bad_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ class Greenstadt1(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return greenstadt1_H_(H=H, s=s, y=y, g_prev=g_prev, tol=settings['tol'])
+
+ class Greenstadt2(QuasiNewtonH):
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return greenstadt2_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+ def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
+     n = H.shape[0]
+
+     j = y.abs().argmax()
+     u = torch.zeros(n, device=H.device, dtype=H.dtype)
+     u[j] = 1.0
+
+     denom = y[j]
+     if denom.abs() < tol: return H
+
+     Hy = H @ y.unsqueeze(1)
+     num = s.unsqueeze(1) - Hy
+
+     H[:, j] += num.squeeze() / denom
+     return H
+
+ class ColumnUpdatingMethod(QuasiNewtonH):
+     """Lopes, V. L., & Martínez, J. M. (1995). Convergence properties of the inverse column-updating method. Optimization Methods & Software, 6(2), 127–144. from https://www.ime.unicamp.br/sites/default/files/pesquisa/relatorios/rp-1993-76.pdf"""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return column_updating_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+     s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
+     I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+     d = (R + I * (s_norm/2)) @ s
+     denom = d.dot(s)
+     if denom.abs() <= tol: return H, R
+     R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
+
+     c = H.T @ d
+     denom = c.dot(y)
+     if denom.abs() <= tol: return H, R
+     num = (H@y).sub_(s).outer(c)
+     H -= num/denom
+     return H, R
+
+ class ThomasOptimalMethod(QuasiNewtonH):
+     """Thomas, Stephen Walter. Sequential estimation techniques for quasi-Newton algorithms. Cornell University, 1975."""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+         H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+         return H
+
+ # ------------------------ powell's symmetric broyden ------------------------ #
+ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+     y_Bs = y - B@s
+     ss = s.dot(s)
+     if ss.abs() < tol: return B
+     num1 = y_Bs.outer(s).add_(s.outer(y_Bs))
+     term1 = num1.div_(ss)
+     term2 = s.outer(s).mul_(y_Bs.dot(s)/(ss**2))
+     B += term1.sub_(term2)
+     return B
+
+ class PSB(HessianUpdateStrategy):
+     def __init__(
+         self,
+         init_scale: float | Literal["auto"] = 'auto',
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         super().__init__(
+             defaults=None,
+             init_scale=init_scale,
+             tol=tol,
+             tol_reset=tol_reset,
+             reset_interval=reset_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             scale_second=scale_second,
+             concat_params=concat_params,
+             inverse=False,
+             inner=inner,
+         )
+
+     def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
+         return psb_B_(B=B, s=s, y=y, tol=settings['tol'])
+
+ def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+     sy = s.dot(y)
+     if sy.abs() <= tol: return H
+     num = (s - H@y).outer(s)
+     H += num.div_(sy)
+     return H
+
+ class Pearson2(QuasiNewtonH):
+     """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
+ def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
+     # in notation p is s, q is y, H is D
+     # another p is lr
+     # omega (o) = sy
+     # tau (t) = yHy
+     # epsilon = p'D^-1 p
+     # however p.12 says eps = gs / gHy
+
+     Hy = H@y
+     gHy = g.dot(Hy)
+     yHy = y.dot(Hy)
+     sy = s.dot(y)
+     if sy < tol: return H
+     if yHy.abs() < tol: return H
+     if gHy.abs() < tol: return H
+
+     v_mul = yHy.sqrt()
+     v_term1 = s/sy
+     v_term2 = Hy/yHy
+     v = (v_term1.sub_(v_term2)).mul_(v_mul)
+     gs = g.dot(s)
+
+     if isinstance(switch, tuple): phi, theta = switch
+     else:
+         o = sy
+         t = yHy
+         e = gs / gHy
+         if switch in (1, 3):
+             if e/o <= 1:
+                 if o.abs() <= tol: return H
+                 phi = e/o
+                 theta = 0
+             elif o/t >= 1:
+                 if t.abs() <= tol: return H
+                 phi = o/t
+                 theta = 1
+             else:
+                 phi = 1
+                 denom = e*t - o**2
+                 if denom.abs() <= tol: return H
+                 if switch == 1: theta = o * (e - o) / denom
+                 else: theta = o * (t - o) / denom
+
+         elif switch == 2:
+             if t.abs() <= tol or o.abs() <= tol or e.abs() <= tol: return H
+             phi = (e / t) ** 0.5
+             theta = 1 / (1 + (t*e / o**2)**0.5)
+
+         elif switch == 4:
+             if t.abs() <= tol: return H
+             phi = e/t
+             theta = 1/2
+
+         else: raise ValueError(switch)
+
+
+     u = phi * (gs/gHy) + (1 - phi) * (sy/yHy)
+     term1 = (H @ y.outer(y) @ H).div_(yHy)
+     term2 = v.outer(v).mul_(theta)
+     term3 = s.outer(s).div_(sy)
+
+     H -= term1
+     H += term2
+     H *= u
+     H += term3
+     return H
+
+
+ class SSVM(HessianUpdateStrategy):
+     """This one is from Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable Metric algorithms. Mathematical Programming, 10(1), 70–90. doi:10.1007/bf01580654
+     """
+     def __init__(
+         self,
+         switch: tuple[float,float] | Literal[1,2,3,4] = 3,
+         init_scale: float | Literal["auto"] = 'auto',
+         tol: float = 1e-10,
+         tol_reset: bool = True,
+         reset_interval: int | None = None,
+         beta: float | None = None,
+         update_freq: int = 1,
+         scale_first: bool = True,
+         scale_second: bool = False,
+         concat_params: bool = True,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(switch=switch)
+         super().__init__(
+             defaults=defaults,
+             init_scale=init_scale,
+             tol=tol,
+             tol_reset=tol_reset,
+             reset_interval=reset_interval,
+             beta=beta,
+             update_freq=update_freq,
+             scale_first=scale_first,
+             scale_second=scale_second,
+             concat_params=concat_params,
+             inverse=True,
+             inner=inner,
+         )
+
+     def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+         return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
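bfgs_H_ above applies the rank-two BFGS update to the inverse Hessian approximation in expanded form, H += ((s.y + y^T H y)/(s.y)^2) * s s^T - (H y s^T + s y^T H)/(s.y). Below is a small standalone check (illustrative only, not part of the package) that this expansion agrees with the textbook product form H' = (I - rho s y^T) H (I - rho y s^T) + rho s s^T, where rho = 1/(s^T y).

import torch

def bfgs_inverse_update(H, s, y):
    # textbook product form of the BFGS inverse-Hessian update
    rho = 1.0 / torch.dot(s, y)
    I = torch.eye(H.shape[0], dtype=H.dtype)
    V = I - rho * torch.outer(s, y)
    return V @ H @ V.T + rho * torch.outer(s, s)

torch.manual_seed(0)
n = 5
A = torch.randn(n, n)
A = A @ A.T + n * torch.eye(n)   # SPD matrix standing in for a Hessian
H = torch.eye(n)                 # current inverse-Hessian approximation
s = torch.randn(n)
y = A @ s                        # curvature pair with s^T y > 0 (the update is skipped otherwise)

sy = torch.dot(s, y)
expanded = H + ((sy + y @ H @ y) / sy**2) * torch.outer(s, s) \
             - (torch.outer(H @ y, s) + torch.outer(s, y) @ H) / sy
assert torch.allclose(bfgs_inverse_update(H, s, y), expanded, atol=1e-5)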
torchzero/modules/second_order/__init__.py
@@ -1,4 +1,3 @@
- r"""
- This includes modules that use the hessian computed via autograd.
- """
- from .newton import ExactNewton, LinearSystemSolvers, FallbackLinearSystemSolvers, LINEAR_SYSTEM_SOLVERS
+ from .newton import Newton
+ from .newton_cg import NewtonCG
+ from .nystrom import NystromSketchAndSolve, NystromPCG