torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/diagonal_higher_order_newton.py (file removed)
@@ -1,225 +0,0 @@
- import itertools
- import math
- import warnings
- from collections.abc import Callable
- from contextlib import nullcontext
- from functools import partial
- from typing import Any, Literal
-
- import numpy as np
- import scipy.optimize
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     jacobian_wrt,
-     hvp,
- )
-
- def _poly_eval_diag(s: np.ndarray, c, derivatives):
-     val = float(c) + (derivatives[0] * s).sum(-1)
-
-     if len(derivatives) > 1:
-         for i, d_diag in enumerate(derivatives[1:], 2):
-             val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
-
-     return val
-
- def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the value of the proximal polynomial approximation."""
-     if x.ndim == 2: x = x.T
-     s = x - x0
-
-     val = _poly_eval_diag(s, c, derivatives)
-
-     penalty = 0
-     if prox != 0:
-         penalty = (prox / 2) * (s**2).sum(-1)
-
-     return val + penalty
-
- def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the gradient of the proximal polynomial approximation."""
-     s = x - x0
-
-     g = derivatives[0].copy()
-
-     if len(derivatives) > 1:
-         for i, d_diag in enumerate(derivatives[1:], 2):
-             g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
-
-     if prox != 0:
-         g += prox * s
-
-     return g
-
- def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
-     """Computes the Hessian of the proximal polynomial approximation."""
-     s = x - x0
-     n = x.shape[0]
-
-     if len(derivatives) < 2:
-         H_diag = np.zeros(n, dtype=s.dtype)
-     else:
-         H_diag = derivatives[1].copy()
-
-     if len(derivatives) > 2:
-         for i, d_diag in enumerate(derivatives[2:], 3):
-             H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
-
-     if prox != 0:
-         H_diag += prox
-
-     return np.diag(H_diag)
-
- def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
-     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
-     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
-     bounds = None
-     if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
-
-     # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
-     if bounds is None:
-         if len(derivatives) == 1: method = 'bfgs'
-         else: method = 'trust-exact'
-     else:
-         if len(derivatives) == 1: method = 'l-bfgs-b'
-         else: method = 'trust-constr'
-
-     x_init = x0.copy()
-     v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
-     if de_iters is not None and de_iters != 0:
-         if de_iters == -1: de_iters = None # let scipy decide
-         res = scipy.optimize.differential_evolution(
-             _proximal_poly_v_diag,
-             bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
-             args=(c, prox, x0.copy(), derivatives),
-             maxiter=de_iters,
-             vectorized=True,
-         )
-         if res.fun < v0: x_init = res.x
-
-     res = scipy.optimize.minimize(
-         _proximal_poly_v_diag,
-         x_init,
-         method=method,
-         args=(c, prox, x0.copy(), derivatives),
-         jac=_proximal_poly_g_diag,
-         hess=_proximal_poly_H_diag,
-         bounds=bounds
-     )
-
-     return torch.from_numpy(res.x).to(x), res.fun
-
-
-
- class DiagonalHigherOrderNewton(Module):
-     """
-     Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
-     except it doesn't work in all cases except ones where it works.
-     """
-     def __init__(
-         self,
-         order: int = 4,
-         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
-         increase: float = 1.5,
-         decrease: float = 0.75,
-         trust_init: float | None = None,
-         trust_tol: float = 1,
-         de_iters: int | None = None,
-         vectorize: bool = True,
-     ):
-         if trust_init is None:
-             if trust_method == 'bounds': trust_init = 1
-             else: trust_init = 0.1
-
-         defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         order = settings['order']
-         increase = settings['increase']
-         decrease = settings['decrease']
-         trust_tol = settings['trust_tol']
-         trust_init = settings['trust_init']
-         trust_method = settings['trust_method']
-         de_iters = settings['de_iters']
-
-         trust_value = self.global_state.get('trust_value', trust_init)
-
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         with torch.enable_grad():
-             loss = var.loss = var.loss_approx = closure(False)
-
-             g = torch.autograd.grad(loss, params, create_graph=True)
-             var.grad = list(g)
-
-             derivatives = [g]
-             T = g # current derivatives tensor diagonal
-             ones = [torch.ones_like(t) for t in g]
-
-             # get all derivatives up to order
-             for o in range(2, order + 1):
-                 T = hvp(params, T, ones, create_graph=o != order)
-                 derivatives.append(T)
-
-         x0 = torch.cat([p.ravel() for p in params])
-
-         if trust_method is None: trust_method = 'none'
-         else: trust_method = trust_method.lower()
-
-         if trust_method == 'none':
-             trust_region = None
-             prox = 0
-
-         elif trust_method == 'bounds':
-             trust_region = trust_value
-             prox = 0
-
-         elif trust_method == 'proximal':
-             trust_region = None
-             prox = 1 / trust_value
-
-         else:
-             raise ValueError(trust_method)
-
-         x_star, expected_loss = _poly_minimize(
-             trust_region=trust_region,
-             prox=prox,
-             de_iters=de_iters,
-             c=loss.item(),
-             x=x0,
-             derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
-         )
-
-         # trust region
-         if trust_method != 'none':
-             expected_reduction = loss - expected_loss
-
-             vec_to_tensors_(x_star, params)
-             loss_star = closure(False)
-             vec_to_tensors_(x0, params)
-             reduction = loss - loss_star
-
-             # failed step
-             if reduction <= 0:
-                 x_star = x0
-                 self.global_state['trust_value'] = trust_value * decrease
-
-             # very good step
-             elif expected_reduction / reduction <= trust_tol:
-                 self.global_state['trust_value'] = trust_value * increase
-
-         difference = vec_to_tensors(x0 - x_star, params)
-         var.update = list(difference)
-         return var
-
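For reference, the removed DiagonalHigherOrderNewton above minimizes a diagonal Taylor model of the loss with an optional proximal penalty; in the notation of _poly_eval_diag and _proximal_poly_v_diag the model is approximately

    m(x) = c + g^\top s + \sum_{i=2}^{k} \frac{1}{i!}\, d_i^\top s^{\odot i} + \frac{\rho}{2}\, \lVert s \rVert^2, \qquad s = x - x_0,

where c is the current loss, g the gradient, d_i the approximate diagonal of the i-th derivative obtained from repeated HVPs with a ones vector, s^{\odot i} the elementwise power, k the order, and rho is 0 for the 'bounds' and 'none' trust modes or 1/trust_value for 'proximal'.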
torchzero/modules/experimental/eigendescent.py (file removed)
@@ -1,117 +0,0 @@
- from contextlib import nullcontext
- import warnings
- from collections.abc import Callable
- from functools import partial
- import itertools
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
- )
-
- def _batched_dot(x, y):
-     return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
-
- def _cosine_similarity(x, y):
-     denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
-     return _batched_dot(x, y) / denom
-
- class EigenDescent(Module):
-     """
-     Uses eigenvectors corresponding to certain eigenvalues. Please note that this is experimental and isn't guaranteed to work.
-
-     Args:
-         mode (str, optional):
-             - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
-             - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
-             - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
-             - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
-             - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
-             - mm - for testing.
-
-             Defaults to 'mean-sign'.
-         hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
-         vectorize (bool, optional): how to calculate hessian. Defaults to True.
-
-     """
-     def __init__(
-         self,
-         mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-     ):
-         defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         mode = settings['mode']
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hessian_method == 'autograd':
-             with torch.enable_grad():
-                 loss = var.loss = var.loss_approx = closure(False)
-                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-                 var.grad = g_list
-                 H = hessian_list_to_mat(H_list)
-
-         elif hessian_method in ('func', 'autograd.functional'):
-             strat = 'forward-mode' if vectorize else 'reverse-mode'
-             with torch.enable_grad():
-                 g_list = var.get_grad(retain_graph=True)
-                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                                               method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-         else:
-             raise ValueError(hessian_method)
-
-
-         # ----------------------------------- solve ---------------------------------- #
-         g = torch.cat([t.ravel() for t in g_list])
-         L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
-         if mode == 'largest':
-             # smallest eigenvalue if all eigenvalues are negative else largest
-             if L[-1] <= 0: d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'smallest':
-             # smallest eigenvalue if negative eigenvalues exist else largest
-             if L[0] <= 0: d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'magnitude':
-             # largest by magnitude
-             if L[0].abs() > L[-1].abs(): d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'mean-dot':
-             d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
-
-         elif mode == 'mean-sign':
-             d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
-
-         elif mode == 'mean-cosine':
-             d = (Q * _cosine_similarity(Q, g)).mean(1)
-
-         elif mode == 'mm':
-             d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
-
-         else:
-             raise ValueError(mode)
-
-         var.update = vec_to_tensors(g.dot(d).sign() * d, params)
-         return var
-
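For reference, the default 'mean-sign' mode of the removed EigenDescent averages the Hessian eigenvectors after flipping each one toward the gradient; in the notation of the code above this is roughly

    d = \frac{1}{n} \sum_{j=1}^{n} \operatorname{sign}(g^\top q_j)\, q_j, \qquad \Delta = \operatorname{sign}(g^\top d)\, d,

where q_1, ..., q_n are the columns of Q returned by torch.linalg.eigh(H) and Delta is the update written back to var.update.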
torchzero/modules/experimental/etf.py (file removed)
@@ -1,172 +0,0 @@
- from typing import cast
- import warnings
-
- import torch
-
- from ...core import Module
- from ...utils import vec_to_tensors, vec_to_tensors_
-
-
- class ExponentialTrajectoryFit(Module):
-     """A method. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, step_size=1e-3):
-         defaults = dict(step_size = step_size)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         step_size = self.settings[var.params[0]]['step_size']
-
-         # 1. perform 3 GD steps to obtain 4 points
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         for i in range(3):
-             if i == 0: grad = var.get_grad()
-             else:
-                 with torch.enable_grad(): closure()
-                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-         assert len(points) == 4, len(points)
-         x0, x1, x2, x3 = points
-         dim = x0.numel()
-
-         # 2. fit a generalized exponential curve
-         d0 = (x1 - x0).unsqueeze(1) # column vectors
-         d1 = (x2 - x1).unsqueeze(1)
-         d2 = (x3 - x2).unsqueeze(1)
-
-         # cat
-         D1 = torch.cat([d0, d1], dim=1)
-         D2 = torch.cat([d1, d2], dim=1)
-
-         # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
-         if x0.numel() >= 2:
-             if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
-                 pass # need to put a quadratic fit there
-
-         M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
-
-         # now we can predict x*
-         I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
-         B = I - M
-         z = x1 - M @ x0
-
-         x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
-
-         vec_to_tensors_(x0, var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
-
-
-
- class ExponentialTrajectoryFitV2(Module):
-     """Should be better than one above, except it isn't. Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, step_size=1e-3, num_steps: int= 4):
-         defaults = dict(step_size = step_size, num_steps=num_steps)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         step_size = self.settings[var.params[0]]['step_size']
-         num_steps = self.settings[var.params[0]]['num_steps']
-
-         # 1. perform 3 GD steps to obtain 4 points (or more)
-         grad = var.get_grad()
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         point_grads = [torch.cat([g.view(-1) for g in grad])]
-
-         for i in range(num_steps):
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-             closure(backward=True)
-             grad = [cast(torch.Tensor, p.grad) for p in var.params]
-             point_grads.append(torch.cat([g.view(-1) for g in grad]))
-
-
-         X = torch.stack(points, 1) # dim, num_steps+1
-         G = torch.stack(point_grads, 1)
-         dim = points[0].numel()
-
-         X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
-
-         P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
-         A = P[:, :dim]
-         b = -P[:, dim]
-
-         # symmetrize
-         A = 0.5 * (A + A.T)
-
-         # predict x*
-         x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
-
-         vec_to_tensors_(points[0], var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
-
-
-
-
- def _fit_exponential(y0, y1, y2):
-     """x0, x1 and x2 are assumed to be 0, 1, 2"""
-     r = (y2 - y1) / (y1 - y0)
-     ones = r==1
-     r[ones] = 0
-     B = (y1 - y0) / (r - 1)
-     A = y0 - B
-
-     A[ones] = 0
-     B[ones] = 0
-     return A, B, r
-
- class PointwiseExponential(Module):
-     """A stupid method (for my youtube channel). Please note that this is experimental and isn't guaranteed to work."""
-     def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
-         defaults = dict(reg=reg, steps=steps, step_size=step_size)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         settings = self.settings[var.params[0]]
-         step_size = settings['step_size']
-         reg = settings['reg']
-         steps = settings['steps']
-
-         # 1. perform 2 GD steps to obtain 3 points
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         for i in range(2):
-             if i == 0: grad = var.get_grad()
-             else:
-                 with torch.enable_grad(): closure()
-                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-         assert len(points) == 3, len(points)
-         y0, y1, y2 = points
-
-         A, B, r = _fit_exponential(y0, y1, y2)
-         r = r.clip(max = 1-reg)
-         x_star = A + B * r**steps
-
-         vec_to_tensors_(y0, var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
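For reference, the removed ExponentialTrajectoryFit fits a linear recursion to four gradient-descent iterates and extrapolates to its fixed point; in the notation of the code above

    M = D_2 D_1^{+}, \quad D_1 = [\,x_1 - x_0,\; x_2 - x_1\,], \quad D_2 = [\,x_2 - x_1,\; x_3 - x_2\,],

and the predicted minimizer x^* solves (I - M)\, x^* = x_1 - M x_0 via torch.linalg.lstsq, which is the fixed point of the fitted map x_{k+1} \approx M x_k + (x_1 - M x_0).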
torchzero/modules/experimental/soapy.py (file removed)
@@ -1,163 +0,0 @@
- from operator import itemgetter
-
- import torch
-
- from ...core import Chainable, Transform
- from ..optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
- from ..optimizers.soap import (
-     update_soap_covariances_,
-     get_orthogonal_matrix,
-     get_orthogonal_matrix_QR,
-     project,
-     project_back,
- )
-
- class SOAPY(Transform):
-     """Adam but uses scaled gradient differences for GGᵀ. Please note that this is experimental and isn't guaranteed to work.
-
-     New args:
-         scale_by_s - whether to scale gradient differences by parameter differences
-         y_to_ema2 - whether to use gradient differences for exponential moving average too
-     """
-     def __init__(
-         self,
-         beta1: float = 0.95,
-         beta2: float = 0.95,
-         shampoo_beta: float | None = 0.95,
-         precond_freq: int = 10,
-         merge_small: bool = True,
-         max_dim: int = 2_000,
-         precondition_1d: bool = True,
-         eps: float = 1e-8,
-         decay: float | None = None,
-         alpha: float = 1,
-         bias_correction: bool = True,
-         scale_by_s: bool = True,
-         y_to_ema2: bool = False,
-     ):
-         defaults = dict(
-             beta1=beta1,
-             beta2=beta2,
-             shampoo_beta=shampoo_beta,
-             precond_freq=precond_freq,
-             merge_small=merge_small,
-             max_dim=max_dim,
-             precondition_1d=precondition_1d,
-             eps=eps,
-             decay=decay,
-             bias_correction=bias_correction,
-             alpha=alpha,
-             scale_by_s=scale_by_s,
-             y_to_ema2=y_to_ema2,
-         )
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def apply(self, tensors, params, grads, loss, states, settings):
-         updates = []
-         # update preconditioners
-         for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
-             beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, alpha = itemgetter(
-                 'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'alpha')(setting)
-             scale_by_s = setting['scale_by_s']
-             y_to_ema2 = setting['y_to_ema2']
-
-             if merge_small:
-                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-
-             if 'g_prev' not in state:
-                 state['p_prev'] = p.clone()
-                 state['g_prev'] = t.clone()
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue
-
-             p_prev = state['p_prev']
-             g_prev = state['g_prev']
-             s = p - p_prev
-             y = t - g_prev
-             if scale_by_s: y /= torch.linalg.norm(s).clip(min=1e-8) # pylint:disable=not-callable
-
-             state['p_prev'].copy_(p)
-             state['g_prev'].copy_(t)
-
-             # initialize state on 1st step
-             if 'GG' not in state:
-                 state["exp_avg"] = torch.zeros_like(t)
-                 if y_to_ema2: state["exp_avg_sq"] = torch.ones_like(t)
-                 else: state["exp_avg_sq"] = torch.zeros_like(t)
-
-                 if not precondition_1d and t.ndim <= 1:
-                     state['GG'] = []
-
-                 else:
-                     state['GG'] = [torch.zeros(sh, sh, dtype=t.dtype, device=t.device) if 1<sh<max_dim else None for sh in t.shape]
-
-                 # either scalar parameter, 1d with precondition_1d=False, or all dims are too big.
-                 if len([i is not None for i in state['GG']]) == 0:
-                     state['GG'] = None
-
-                 if state['GG'] is not None:
-                     update_soap_covariances_(y, GGs_=state['GG'], beta=shampoo_beta)
-                     state['Q'] = get_orthogonal_matrix(state['GG'])
-
-                 state['step'] = 0
-                 updates.append(tensors[i].clip(-0.1,0.1))
-                 continue # skip 1st step as in https://github.com/nikhilvyas/SOAP/blob/main/soap.py ?
-                 # I use sign instead as to not mess up with next modules. 1st Adam step is always sign anyway.
-
-             # Projecting gradients to the eigenbases of Shampoo's preconditioner
-             # i.e. projecting to the eigenbases of matrices in state['GG']
-             z_projected = None
-             if state['GG'] is not None:
-                 if y_to_ema2: z_projected = project(y, state['Q'])
-                 else: z_projected = project(t, state['Q'])
-
-             # exponential moving averages
-             # this part could be foreached but I will do that at some point its not a big difference compared to preconditioning
-             exp_avg: torch.Tensor = state["exp_avg"]
-             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
-
-             exp_avg.lerp_(t, 1-beta1)
-
-             if z_projected is None:
-                 if y_to_ema2: exp_avg_sq.mul_(beta2).addcmul_(y, y, value=1-beta2)
-                 else: exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
-             else:
-                 exp_avg_sq.mul_(beta2).addcmul_(z_projected, z_projected, value=1-beta2)
-
-             # project exponential moving averages if they are accumulated unprojected
-             exp_avg_projected = exp_avg
-             if z_projected is not None:
-                 exp_avg_projected = project(exp_avg, state['Q'])
-
-             exp_avg_sq_projected = exp_avg_sq
-
-             denom = exp_avg_sq_projected.sqrt().add_(eps)
-             # print(f'{t_projected = }, {exp_avg = }, {exp_avg_projected = }, {exp_avg_sq = }, {exp_avg_sq_projected = }, {denom = }')
-
-             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
-             # to the original space
-             update = exp_avg_projected / denom
-             if z_projected is not None:
-                 update = project_back(update, state["Q"])
-
-             if setting['bias_correction']:
-                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
-                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
-                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
-             elif alpha is not None:
-                 update *= alpha
-
-             if merge_small:
-                 update = _unmerge_small_dims(update, state['flat_sizes'], state['sort_idxs'])
-
-             updates.append(update)
-             state["step"] += 1
-
-             # Update is done after the gradient step to avoid using current gradients in the projection.
-             if state['GG'] is not None:
-                 update_soap_covariances_(y, state['GG'], shampoo_beta)
-                 if state['step'] % setting['precond_freq'] == 0:
-                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
-
-         return updates
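For reference, the removed SOAPY follows the SOAP/Adam scheme but accumulates the Shampoo covariances GG^T from scaled gradient differences rather than from gradients; per the code above, with scale_by_s=True,

    y_t = \frac{g_t - g_{t-1}}{\max(\lVert x_t - x_{t-1} \rVert,\; 10^{-8})},

and update_soap_covariances_ receives y_t in place of g_t, while the Adam moments are still built from g_t unless y_to_ema2=True.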