torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
@@ -2,181 +2,41 @@ from abc import ABC, abstractmethod
  import math
  from collections import deque
  from typing import Literal, Any
+ import itertools

  import torch
- from ...core import Chainable, TensorwisePreconditioner
+ from ...core import Chainable, TensorwiseTransform
  from ...utils.linalg.matrix_funcs import matrix_power_eigh
  from ...utils.linalg.svd import randomized_svd
  from ...utils.linalg.qr import qr_householder

+ def spectral_update(history, damping, rdamping, true_damping: bool):
+     M_hist = torch.stack(tuple(history), dim=1)
+     device = M_hist.device
+     M_hist = M_hist.cuda()

- class _Solver:
-     @abstractmethod
-     def update(self, history: deque[torch.Tensor], damping: float | None) -> tuple[Any, Any]:
-         """returns stuff for apply"""
-     @abstractmethod
-     def apply(self, __g: torch.Tensor, __A:torch.Tensor, __B:torch.Tensor) -> torch.Tensor:
-         """apply preconditioning to tensor"""
+     try:
+         U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver='gesvda') # pylint:disable=not-callable
+         U = U.to(device); S = S.to(device)

- class _SVDSolver(_Solver):
-     def __init__(self, driver=None): self.driver=driver
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         device = None # driver is CUDA only
-         if self.driver is not None:
-             device = M_hist.device
-             M_hist = M_hist.cuda()
+         if damping != 0 or rdamping != 0:
+             if rdamping != 0: rdamping *= torch.linalg.vector_norm(S) # pylint:disable=not-callable
+             Iu = damping + rdamping
+             if true_damping:
+                 S.pow_(2)
+                 Iu **= 2
+             S.add_(Iu)
+             if true_damping: S.sqrt_()

-         try:
-             U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable
+         return U, 1/S

-             if self.driver is not None:
-                 U = U.to(device); S = S.to(device)
+     except torch.linalg.LinAlgError:
+         return None, None

-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
+ def spectral_apply(g: torch.Tensor, U: torch.Tensor, S_inv: torch.Tensor):
+     Utg = (U.T @ g)*S_inv
+     return U @ Utg

-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _SVDLowRankSolver(_Solver):
-     def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _RandomizedSVDSolver(_Solver):
-     def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
-         self.driver = driver
-         self.k = k
-
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         device = None # driver is CUDA only
-         if self.driver is not None:
-             device = M_hist.device
-             M_hist = M_hist.cuda()
-
-         try:
-             U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)
-
-             if self.driver is not None:
-                 U = U.to(device); S = S.to(device)
-
-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
-
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _QRDiagonalSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-             R_diag = R.diag().abs()
-             if damping is not None and damping != 0: R_diag.add_(damping)
-             if self.sqrt: R_diag.sqrt_()
-             return Q, R_diag
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
-         Qtg = (Q.T @ g).div_(R_diag)
-         return Q @ Qtg
-
- class _QRSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             # Q: d x k, R: k x k
-             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-             A = R @ R.T
-             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-             if self.sqrt: A = matrix_power_eigh(A, 0.5)
-             return Q, A
-         except (torch.linalg.LinAlgError):
-             return None,None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-         g_proj = Q.T @ g
-         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-         return Q @ y
-
- class _QRHouseholderSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             # Q: d x k, R: k x k
-             Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
-             A = R @ R.T
-             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-             if self.sqrt: A = matrix_power_eigh(A, 0.5)
-             return Q, A
-         except (torch.linalg.LinAlgError):
-             return None,None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-         g_proj = Q.T @ g
-         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-         return Q @ y
-
-
- class _EighSolver(_Solver):
-     def __init__(self, sqrt=True):
-         self.sqrt = sqrt
-
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         grams = M_hist @ M_hist.T # (d, d)
-         if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
-         try:
-             L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
-             L = L.abs().clamp_(min=1e-12)
-             if self.sqrt: L = L.sqrt()
-             return Q, L
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
-         Qtg = (Q.T @ g).div_(L)
-         return Q @ Qtg
-
-
- SOLVERS = {
-     "svd": _SVDSolver(), # fallbacks on "gesvd" which basically takes ages or just hangs completely
-     "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback on slow "gesvd"
-     "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse but faster
-     "svd_lowrank": _SVDLowRankSolver(), # maybe need to tune parameters for this, with current ones its slower and worse
-     "randomized_svd2": _RandomizedSVDSolver(2),
-     "randomized_svd3": _RandomizedSVDSolver(3),
-     "randomized_svd4": _RandomizedSVDSolver(4),
-     "randomized_svd5": _RandomizedSVDSolver(5),
-     "eigh": _EighSolver(), # this is O(n**2) storage, but is this more accurate?
-     "qr": _QRSolver(),
-     "qr_householder": _QRHouseholderSolver(), # this is slower... but maybe it won't freeze? I think svd_gesvda is better
-     "qrdiag": _QRDiagonalSolver(),
- }

  def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
      if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
@@ -184,63 +44,76 @@ def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
          if state_[key].shape != value.shape: state_[key] = value
          else: state_[key].lerp_(value, 1-beta)

- class SpectralPreconditioner(TensorwisePreconditioner):
-     """Whitening preconditioner via SVD on history of past gradients or gradient differences scaled by parameter differences.
+ class SpectralPreconditioner(TensorwiseTransform):
+     """
+     The update rule is to stack recent gradients into M, compute U, S <- SVD(M), then calculate U (Uᵀg)/S.
+     This is equivalent to full matrix Adagrad with accumulator initialized to zeros,
+     except only recent :code:`history_size` gradients are used.
+     However this doesn't require N^2 memory and is computationally less expensive than Shampoo.

      Args:
-         history_size (int, optional): number of past gradients to store for preconditioning. Defaults to 10.
-         update_freq (int, optional): how often to re-compute the preconditioner. Defaults to 1.
-         damping (float, optional): damping term, makes it closer to GD. Defaults to 1e-7.
+         history_size (int, optional): number of past gradients to store. Defaults to 10.
+         update_freq (int, optional): frequency of updating the preconditioner (U and S). Defaults to 1.
+         damping (float, optional): damping value. Defaults to 1e-4.
+         rdamping (float, optional): value of damping relative to singular values norm. Defaults to 0.
          order (int, optional):
-             whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
-         solver (str, optional): what to use for whitening. Defaults to 'svd'.
-         A_beta (float | None, optional):
-             beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
-         B_beta (float | None, optional):
-             beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
-         interval (int, optional): How often to update history. Defaults to 1 (every step).
-         concat_params (bool, optional):
-             whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
-         scale_first (bool, optional): makes first step small, usually not needed. Defaults to False.
-         inner (Chainable | None, optional): Inner modules applied after updating preconditioner and before applying it. Defaults to None.
+             order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
+         true_damping (bool, optional):
+             If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
+         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
+         S_beta (float | None, optional): momentum for 1/S (too unstable, don't use). Defaults to None.
+         interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
+         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to False.
+         normalize (bool, optional): whether to normalize gradients, this doesn't work well so don't use it. Defaults to False.
+         centralize (bool, optional): whether to centralize gradients, this doesn't work well so don't use it. Defaults to False.
+         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
      """
+
      def __init__(
          self,
          history_size: int = 10,
          update_freq: int = 1,
-         damping: float = 1e-12,
+         damping: float = 1e-4,
+         rdamping: float = 0,
          order: int = 1,
-         solver: Literal['svd', 'svd_gesvdj', 'svd_gesvda', 'svd_lowrank', 'eigh', 'qr', 'qrdiag', 'qr_householder'] | _Solver | str = 'svd_gesvda',
-         A_beta: float | None = None,
-         B_beta: float | None = None,
+         true_damping: bool = True,
+         U_beta: float | None = None,
+         S_beta: float | None = None,
          interval: int = 1,
          concat_params: bool = False,
-         scale_first: bool = False,
+         normalize: bool=False,
+         centralize:bool = False,
          inner: Chainable | None = None,
      ):
-         if isinstance(solver, str): solver = SOLVERS[solver]
          # history is still updated each step so Precondition's update_freq has different meaning
-         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order, A_beta=A_beta, B_beta=B_beta, solver=solver)
-         super().__init__(defaults, uses_grad=False, concat_params=concat_params, scale_first=scale_first, inner=inner, update_freq=interval)
+         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, rdamping=rdamping, true_damping=true_damping, order=order, U_beta=U_beta, S_beta=S_beta, normalize=normalize, centralize=centralize)
+         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner, update_freq=interval)

      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, state, settings):
+     def update_tensor(self, tensor, param, grad, loss, state, settings):
          order = settings['order']
          history_size = settings['history_size']
          update_freq = settings['update_freq']
          damping = settings['damping']
-         A_beta = settings['A_beta']
-         B_beta = settings['B_beta']
-         solver: _Solver = settings['solver']
+         rdamping = settings['rdamping']
+         true_damping = settings['true_damping']
+         U_beta = settings['U_beta']
+         S_beta = settings['S_beta']
+         normalize = settings['normalize']
+         centralize = settings['centralize']

          if 'history' not in state: state['history'] = deque(maxlen=history_size)
          history = state['history']

-         if order == 1: history.append(tensor.clone().view(-1))
+         if order == 1:
+             t = tensor.clone().view(-1)
+             if centralize: t -= t.mean()
+             if normalize: t /= torch.linalg.vector_norm(t).clip(min=1e-8) # pylint:disable=not-callable
+             history.append(t)
          else:

              # if order=2, history is of gradient differences, order 3 is differences between differences, etc
-             # normalized by parameter differences
+             # scaled by parameter differences
              cur_p = param.clone()
              cur_g = tensor.clone()
              for i in range(1, order):
@@ -257,32 +130,34 @@ class SpectralPreconditioner(TensorwisePreconditioner):
                  cur_g = y_k

                  if i == order - 1:
-                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                     if centralize: cur_g = cur_g - cur_g.mean()
+                     if normalize: cur_g = cur_g / torch.linalg.vector_norm(cur_g).clip(min=1e-8) # pylint:disable=not-callable
+                     else: cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
              history.append(cur_g.view(-1))

          step = state.get('step', 0)
          if step % update_freq == 0 and len(history) != 0:
-             A, B = solver.update(history, damping=damping)
-             maybe_lerp_(state, A_beta, 'A', A)
-             maybe_lerp_(state, B_beta, 'B', B)
+             U, S_inv = spectral_update(history, damping=damping, rdamping=rdamping, true_damping=true_damping)
+             maybe_lerp_(state, U_beta, 'U', U)
+             maybe_lerp_(state, S_beta, 'S_inv', S_inv)

          if len(history) != 0:
              state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)

      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, state, settings):
+     def apply_tensor(self, tensor, param, grad, loss, state, settings):
          history_size = settings['history_size']
-         solver: _Solver = settings['solver']

-         A = state.get('A', None)
-         if A is None:
+         U = state.get('U', None)
+         if U is None:
              # make a conservative step to avoid issues due to different GD scaling
              return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]

-         B = state['B']
-         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
+         S_inv = state['S_inv']
+         update = spectral_apply(tensor.view(-1), U, S_inv).view_as(tensor)

          n = len(state['history'])
-         if n != history_size: update.mul_(n/history_size)
+         mh = min(history_size, 10)
+         if n <= mh: update.mul_(n/mh)
          return update

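The new docstring above states the update rule: stack recent gradients into M, take U, S <- SVD(M), and return U (Uᵀg)/S. A minimal standalone sketch of that rule outside the module machinery (the helper name and the simplified true-damping handling are illustrative, not part of torchzero's API):

import torch
from collections import deque

def whiten_with_history(g, history, damping=1e-4):
    # stack the k most recent gradient vectors into a d x k matrix M
    M = torch.stack(list(history), dim=1)
    # thin SVD: M = U diag(S) V^T, with U of shape d x k
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    # damp the squared singular values (mirrors true_damping=True), then invert
    S_inv = (S.pow(2) + damping ** 2).rsqrt()
    # precondition: U diag(1/S) U^T g, computed as U ((U^T g) * S_inv)
    return U @ ((U.T @ g) * S_inv)

# usage sketch: keep a rolling window of flattened gradients
history = deque(maxlen=10)
for _ in range(4):
    history.append(torch.randn(20))
update = whiten_with_history(torch.randn(20), history)
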
@@ -0,0 +1,111 @@
+ # idea https://arxiv.org/pdf/2212.09841
+ import warnings
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     hessian_mat,
+     hvp,
+     hvp_fd_central,
+     hvp_fd_forward,
+     jacobian_and_hessian_wrt,
+ )
+
+
+ class StructuredNewton(Module):
+     """TODO. Please note that this is experimental and isn't guaranteed to work.
+     Args:
+         structure (str, optional): structure.
+         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
+         hvp_method (str):
+             how to calculate hvp_method. Defaults to "autograd".
+         inner (Chainable | None, optional): inner modules. Defaults to None.
+
+     """
+     def __init__(
+         self,
+         structure: Literal[
+             "diagonal",
+             "diagonal1",
+             "diagonal_abs",
+             "tridiagonal",
+             "circulant",
+             "toeplitz",
+             "toeplitz_like",
+             "hankel",
+             "rank1",
+             "rank2", # any rank
+         ]
+         | str = "diagonal",
+         reg: float = 1e-6,
+         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
+         h: float = 1e-3,
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         reg = settings['reg']
+         hvp_method = settings['hvp_method']
+         structure = settings['structure']
+         h = settings['h']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hvp_method == 'autograd':
+             grad = var.get_grad(create_graph=True)
+             def Hvp_fn1(x):
+                 return hvp(params, grad, x, retain_graph=True)
+             Hvp_fn = Hvp_fn1
+
+         elif hvp_method == 'forward':
+             grad = var.get_grad()
+             def Hvp_fn2(x):
+                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
+             Hvp_fn = Hvp_fn2
+
+         elif hvp_method == 'central':
+             grad = var.get_grad()
+             def Hvp_fn3(x):
+                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
+             Hvp_fn = Hvp_fn3
+
+         else: raise ValueError(hvp_method)
+
+         # -------------------------------- inner step -------------------------------- #
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], update, params=params, grads=grad, var=var)
+
+         # hessian
+         if structure.startswith('diagonal'):
+             H = Hvp_fn([torch.ones_like(p) for p in params])
+             if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
+             if structure == 'diagonal_abs': torch._foreach_abs_(H)
+             torch._foreach_add_(H, reg)
+             torch._foreach_div_(update, H)
+             var.update = update
+             return var
+
+         # hessian
+         raise NotImplementedError(structure)
+
+
+
+
+
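The 'diagonal' branch of the new file above builds its preconditioner from a single Hessian-vector product with an all-ones vector, i.e. H @ 1 (Hessian row sums), which matches the true diagonal only when off-diagonal entries are negligible. A self-contained sketch of that estimate using plain autograd, with an illustrative helper name and toy objective (not torchzero's API):

import torch

def hvp_ones_diagonal(loss_fn, x):
    # estimate a "diagonal" of the Hessian as H @ 1 (row sums) via one Hessian-vector product;
    # equals the exact diagonal only when off-diagonal entries are negligible
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    ones = torch.ones_like(x)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=ones)
    return Hv

# usage sketch: a separable quartic, whose Hessian really is diagonal (12*x**2 + 2)
x = torch.randn(5, requires_grad=True)
f = lambda p: (p ** 4).sum() + (p ** 2).sum()
g = torch.autograd.grad(f(x), x)[0]
D = hvp_ones_diagonal(f, x) + 1e-6   # Tikhonov regularization, like the 'reg' setting
step = g / D                         # diagonally preconditioned update
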
@@ -5,7 +5,7 @@ import torch

  # import torchzero as tz

- from ...core import Transform, Chainable, apply
+ from ...core import Transform, Chainable, apply_transform
  from ...utils.linalg import inv_sqrt_2x2, matrix_power_eigh, gram_schmidt
  from ...utils import TensorList, vec_to_tensors_

@@ -38,15 +38,15 @@ def apply_subspace_preconditioner(
      return basis @ update_projected # d

  class RandomSubspacePreconditioning(Transform):
-     """full matrix rmsprop in random slowly changing subspace"""
+     """Whitens in random slowly changing subspace. Please note that this is experimental and isn't guaranteed to work."""
      def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
          defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
          super().__init__(defaults, uses_grad=False)

          if inner is not None: self.set_child('inner', inner)

-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
+     def apply(self, tensors, params, grads, loss, states, settings):
+         settings = settings[0]
          g = torch.cat([t.view(-1) for t in tensors])
          k = settings['k']
          beta = settings['beta']
@@ -65,7 +65,7 @@ class RandomSubspacePreconditioning(Transform):
          update_subspace_preconditioner_(g, basis, accumulator, beta)

          if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors, params, grads, vars)
+             tensors = apply_transform(self.children['inner'], tensors, params, grads)
          g = torch.cat([t.view(-1) for t in tensors])

          try:
@@ -78,9 +78,12 @@ class RandomSubspacePreconditioning(Transform):


  class HistorySubspacePreconditioning(Transform):
-     """full matrix rmsprop in subspace spanned by history of gradient differences
+     """Whitens in subspace spanned by history of gradient differences.
+     Please note that this is experimental and isn't guaranteed to work.

-     basis_beta is how much basis is allowed to change, and beta is for preconditioner itself in the basis.
+     Args:
+         beta - for preconditioner itself in the basis.
+         basis_beta - how much basis is allowed to change.
      """
      def __init__(self, k: int, beta: float | None = 0.99, basis_beta=0.99, inner: Chainable | None = None):
          defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
@@ -88,8 +91,8 @@ class HistorySubspacePreconditioning(Transform):

          if inner is not None: self.set_child('inner', inner)

-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
+     def apply(self, tensors, params, grads, loss, states, settings):
+         settings = settings[0]

          g = torch.cat([t.view(-1) for t in tensors])
          k = settings['k']
@@ -122,7 +125,7 @@ class HistorySubspacePreconditioning(Transform):
          update_subspace_preconditioner_(g, basis, accumulator, beta)

          if 'inner' in self.children:
-             tensors = apply(self.children['inner'], tensors, params, grads, vars)
+             tensors = apply_transform(self.children['inner'], tensors, params, grads)
          g = torch.cat([t.view(-1) for t in tensors])

          try:
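Both classes above whiten the update inside a low-dimensional basis: project the gradient into a d x k basis, accumulate a full k x k second-moment matrix there, and precondition by its inverse square root before projecting back (the diff shows the projection via apply_subspace_preconditioner). A minimal sketch of that idea, with illustrative names rather than torchzero's internal helpers:

import torch

def subspace_whiten(g, basis, cov, beta=0.99, eps=1e-8):
    # basis: d x k with (roughly) orthonormal columns, g: d-vector, cov: running k x k second moment
    g_proj = basis.T @ g                                              # project into the subspace
    cov.mul_(beta).add_(torch.outer(g_proj, g_proj), alpha=1 - beta)  # EMA of projected outer products
    L, Q = torch.linalg.eigh(cov + eps * torch.eye(cov.shape[0]))     # eigendecomposition of the k x k moment
    whitened = Q @ ((Q.T @ g_proj) / L.clamp(min=eps).sqrt())         # apply cov^{-1/2} in the subspace
    return basis @ whitened                                           # map back to the full space

# usage sketch
d, k = 100, 8
basis, _ = torch.linalg.qr(torch.randn(d, k))   # a random orthonormal basis
cov = torch.zeros(k, k)
update = subspace_whiten(torch.randn(d), basis, cov)
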
@@ -0,0 +1,38 @@
+ from collections import deque
+
+ import torch
+
+ from ...core import Chainable, TensorwiseTransform
+ from ...utils.linalg import matrix_power_eigh
+
+
+ class TAda(TensorwiseTransform):
+     """3rd order whitening (maybe normalizes skewness). Please note that this is experimental and isn't guaranteed to work."""
+     def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
+         defaults = dict(history_size=history_size, reg=reg)
+         super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)
+
+     @torch.no_grad
+     def update_tensor(self, tensor, param, grad, loss, state, settings):
+         reg = settings['reg']
+         if 'history' not in state:
+             state['history'] = deque(maxlen=settings['history_size'])
+
+         g = tensor.view(-1)
+         history = state['history']
+         history.append(g.clone())
+
+         I = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype).mul_(reg)
+         g_k = history[0]
+         outer = torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+         if len(history) > 1:
+             for g_k in list(history)[1:]:
+                 outer += torch.outer(g_k, g_k).mul_(torch.dot(g_k, g).clip(min=reg))
+
+         state['outer'] = outer.add_(I)
+
+     @torch.no_grad
+     def apply_tensor(self, tensor, param, grad, loss, state, settings):
+         outer = state['outer']
+         P = matrix_power_eigh(outer, -1/2)
+         return (P @ tensor.ravel()).view_as(tensor)
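The new TAda module above accumulates A = sum_k <g_k, g> g_k g_kᵀ + reg*I over the stored gradient history and applies A^(-1/2) to the current gradient. A standalone sketch of the same accumulation that inlines an eigendecomposition-based inverse square root in place of torchzero's matrix_power_eigh (names are illustrative):

import torch

def inv_sqrt_psd(A):
    # A^{-1/2} for a symmetric positive semi-definite matrix, via eigendecomposition
    L, Q = torch.linalg.eigh(A)
    return Q @ torch.diag(L.clamp(min=1e-12).rsqrt()) @ Q.T

def tada_precondition(g, history, reg=1e-8):
    A = reg * torch.eye(g.numel(), dtype=g.dtype, device=g.device)
    for g_k in history:
        # weight each outer product g_k g_k^T by its alignment with the current gradient
        A = A + torch.dot(g_k, g).clip(min=reg) * torch.outer(g_k, g_k)
    return inv_sqrt_psd(A) @ g

# usage sketch
history = [torch.randn(6) for _ in range(4)]
update = tada_precondition(torch.randn(6), history)
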
@@ -93,14 +93,14 @@ class FDM(GradApproximator):
      Args:
          h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
          formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
-         target (GradTarget, optional): what to set on vars. Defaults to 'closure'.
+         target (GradTarget, optional): what to set on var. Defaults to 'closure'.
      """
      def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
          defaults = dict(h=h, formula=formula)
          super().__init__(defaults, target=target)

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss, var):
          grads = []
          loss_approx = None

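The FDM docstring above documents a perturbation size h and a 'central2' formula. A minimal coordinate-wise sketch of the central-difference gradient approximation it refers to (not torchzero's actual implementation, which operates on parameter lists through a closure):

import torch

def central_difference_grad(closure, x, h=1e-3):
    # approximate df/dx_i with (f(x + h*e_i) - f(x - h*e_i)) / (2h) for each coordinate
    g = torch.zeros_like(x)
    for i in range(x.numel()):
        e = torch.zeros_like(x)
        e.view(-1)[i] = h
        g.view(-1)[i] = (closure(x + e) - closure(x - e)) / (2 * h)
    return g

# usage sketch: quadratic f(x) = sum(x^2), whose true gradient is 2x
x = torch.tensor([1.0, -2.0, 3.0])
g = central_difference_grad(lambda v: (v ** 2).sum(), x)
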
@@ -17,13 +17,13 @@ class ForwardGradient(RandomizedFDM):
          n_samples (int, optional): number of random gradient samples. Defaults to 1.
          distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
          beta (float, optional):
-             if not 0, acts as momentum on gradient samples, making the subspace spanned by them change slowly. Defaults to 0.
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
          pre_generate (bool, optional):
-             whether to pre-generate gradient samples before each step. Defaults to True.
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
          jvp_method (str, optional):
-             how to calculate jacobian vector product, note that with `forward` and 'central' this is identical to randomized finite difference. Defaults to 'autograd'.
+             how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
          h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
-         target (GradTarget, optional): what to set on vars. Defaults to "closure".
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
      """
      PRE_MULTIPLY_BY_H = False
      def __init__(
@@ -41,7 +41,7 @@ class ForwardGradient(RandomizedFDM):
          self.defaults['jvp_method'] = jvp_method

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss, var):
          params = TensorList(params)
          loss_approx = None
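The ForwardGradient docstring above describes estimating gradients from directional derivatives along random directions: sample a direction v, compute the forward-mode directional derivative f'(x; v), and use f'(x; v)·v as an unbiased gradient estimate. A small sketch with torch.func.jvp (requires PyTorch 2.x; the helper name is illustrative, not torchzero's API):

import torch
from torch.func import jvp

def forward_gradient(f, x, n_samples=1):
    # estimate grad f(x) as the average of (df/dv) * v over random Gaussian directions v
    est = torch.zeros_like(x)
    for _ in range(n_samples):
        v = torch.randn_like(x)
        _, dfdv = jvp(f, (x,), (v,))   # forward-mode directional derivative f'(x; v)
        est += dfdv * v
    return est / n_samples

# usage sketch: f(x) = sum(x^2); the estimate is unbiased for the true gradient 2x
x = torch.tensor([0.5, -1.0, 2.0])
g_hat = forward_gradient(lambda v: (v ** 2).sum(), x, n_samples=8)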