torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -1,288 +0,0 @@
- from abc import ABC, abstractmethod
- import math
- from collections import deque
- from typing import Literal, Any
-
- import torch
- from ...core import Chainable, TensorwisePreconditioner
- from ...utils.linalg.matrix_funcs import matrix_power_eigh
- from ...utils.linalg.svd import randomized_svd
- from ...utils.linalg.qr import qr_householder
-
-
- class _Solver:
-     @abstractmethod
-     def update(self, history: deque[torch.Tensor], damping: float | None) -> tuple[Any, Any]:
-         """returns stuff for apply"""
-     @abstractmethod
-     def apply(self, __g: torch.Tensor, __A:torch.Tensor, __B:torch.Tensor) -> torch.Tensor:
-         """apply preconditioning to tensor"""
-
- class _SVDSolver(_Solver):
-     def __init__(self, driver=None): self.driver=driver
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         device = None # driver is CUDA only
-         if self.driver is not None:
-             device = M_hist.device
-             M_hist = M_hist.cuda()
-
-         try:
-             U, S, _ = torch.linalg.svd(M_hist, full_matrices=False, driver=self.driver) # pylint:disable=not-callable
-
-             if self.driver is not None:
-                 U = U.to(device); S = S.to(device)
-
-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
-
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _SVDLowRankSolver(_Solver):
-     def __init__(self, q: int = 6, niter: int = 2): self.q, self.niter = q, niter
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             U, S, _ = torch.svd_lowrank(M_hist, q=self.q, niter=self.niter)
-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _RandomizedSVDSolver(_Solver):
-     def __init__(self, k: int = 3, driver: str | None = 'gesvda'):
-         self.driver = driver
-         self.k = k
-
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         device = None # driver is CUDA only
-         if self.driver is not None:
-             device = M_hist.device
-             M_hist = M_hist.cuda()
-
-         try:
-             U, S, _ = randomized_svd(M_hist, k=self.k, driver=self.driver)
-
-             if self.driver is not None:
-                 U = U.to(device); S = S.to(device)
-
-             if damping is not None and damping != 0: S.add_(damping)
-             return U, S
-
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, U: torch.Tensor, S: torch.Tensor):
-         Utg = (U.T @ g).div_(S)
-         return U @ Utg
-
- class _QRDiagonalSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-             R_diag = R.diag().abs()
-             if damping is not None and damping != 0: R_diag.add_(damping)
-             if self.sqrt: R_diag.sqrt_()
-             return Q, R_diag
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, R_diag: torch.Tensor):
-         Qtg = (Q.T @ g).div_(R_diag)
-         return Q @ Qtg
-
- class _QRSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             # Q: d x k, R: k x k
-             Q, R = torch.linalg.qr(M_hist, mode='reduced') # pylint:disable=not-callable
-             A = R @ R.T
-             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-             if self.sqrt: A = matrix_power_eigh(A, 0.5)
-             return Q, A
-         except (torch.linalg.LinAlgError):
-             return None,None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-         g_proj = Q.T @ g
-         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-         return Q @ y
-
- class _QRHouseholderSolver(_Solver):
-     def __init__(self, sqrt=True): self.sqrt = sqrt
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         try:
-             # Q: d x k, R: k x k
-             Q, R = qr_householder(M_hist, mode='reduced') # pylint:disable=not-callable
-             A = R @ R.T
-             if damping is not None and damping != 0: A.diagonal(dim1=-2, dim2=-1).add_(damping)
-             if self.sqrt: A = matrix_power_eigh(A, 0.5)
-             return Q, A
-         except (torch.linalg.LinAlgError):
-             return None,None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
-         g_proj = Q.T @ g
-         y, _ = torch.linalg.solve_ex(A, g_proj) # pylint:disable=not-callable
-         return Q @ y
-
-
- class _EighSolver(_Solver):
-     def __init__(self, sqrt=True):
-         self.sqrt = sqrt
-
-     def update(self, history, damping):
-         M_hist = torch.stack(tuple(history), dim=1)
-         grams = M_hist @ M_hist.T # (d, d)
-         if damping is not None and damping != 0: grams.diagonal(dim1=-2, dim2=-1).add_(damping)
-         try:
-             L, Q = torch.linalg.eigh(grams) # L: (d,), Q: (d, d) # pylint:disable=not-callable
-             L = L.abs().clamp_(min=1e-12)
-             if self.sqrt: L = L.sqrt()
-             return Q, L
-         except torch.linalg.LinAlgError:
-             return None, None
-
-     def apply(self, g: torch.Tensor, Q: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
-         Qtg = (Q.T @ g).div_(L)
-         return Q @ Qtg
-
-
- SOLVERS = {
-     "svd": _SVDSolver(), # fallbacks on "gesvd" which basically takes ages or just hangs completely
-     "svd_gesvdj": _SVDSolver("gesvdj"), # no fallback on slow "gesvd"
-     "svd_gesvda": _SVDSolver("gesvda"), # approximate method for wide matrices, sometimes better sometimes worse but faster
-     "svd_lowrank": _SVDLowRankSolver(), # maybe need to tune parameters for this, with current ones its slower and worse
-     "randomized_svd2": _RandomizedSVDSolver(2),
-     "randomized_svd3": _RandomizedSVDSolver(3),
-     "randomized_svd4": _RandomizedSVDSolver(4),
-     "randomized_svd5": _RandomizedSVDSolver(5),
-     "eigh": _EighSolver(), # this is O(n**2) storage, but is this more accurate?
-     "qr": _QRSolver(),
-     "qr_householder": _QRHouseholderSolver(), # this is slower... but maybe it won't freeze? I think svd_gesvda is better
-     "qrdiag": _QRDiagonalSolver(),
- }
-
- def maybe_lerp_(state_: dict, beta: float | None, key, value: Any):
-     if (key not in state_) or (beta is None) or (not isinstance(value, torch.Tensor)): state_[key] = value
-     else:
-         if state_[key].shape != value.shape: state_[key] = value
-         else: state_[key].lerp_(value, 1-beta)
-
- class SpectralPreconditioner(TensorwisePreconditioner):
-     """Whitening preconditioner via SVD on history of past gradients or gradient differences scaled by parameter differences.
-
-     Args:
-         history_size (int, optional): number of past gradients to store for preconditioning. Defaults to 10.
-         update_freq (int, optional): how often to re-compute the preconditioner. Defaults to 1.
-         damping (float, optional): damping term, makes it closer to GD. Defaults to 1e-7.
-         order (int, optional):
-             whitening order, 1 approximates FIM (maybe), 2 - hessian (maybe), 3+ - god knows what.
-         solver (str, optional): what to use for whitening. Defaults to 'svd'.
-         A_beta (float | None, optional):
-             beta for U (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
-         B_beta (float | None, optional):
-             beta for S (in SVD and other letters in other solvers) (probably a bad idea). Defaults to None.
-         interval (int, optional): How often to update history. Defaults to 1 (every step).
-         concat_params (bool, optional):
-             whether to apply preconditioning to each tensor (False, default) or to all tensors concatenated into a vector (True). Latter will be slower but captures interactions between layers. Defaults to True.
-         scale_first (bool, optional): makes first step small, usually not needed. Defaults to False.
-         inner (Chainable | None, optional): Inner modules applied after updating preconditioner and before applying it. Defaults to None.
-     """
-     def __init__(
-         self,
-         history_size: int = 10,
-         update_freq: int = 1,
-         damping: float = 1e-12,
-         order: int = 1,
-         solver: Literal['svd', 'svd_gesvdj', 'svd_gesvda', 'svd_lowrank', 'eigh', 'qr', 'qrdiag', 'qr_householder'] | _Solver | str = 'svd_gesvda',
-         A_beta: float | None = None,
-         B_beta: float | None = None,
-         interval: int = 1,
-         concat_params: bool = False,
-         scale_first: bool = False,
-         inner: Chainable | None = None,
-     ):
-         if isinstance(solver, str): solver = SOLVERS[solver]
-         # history is still updated each step so Precondition's update_freq has different meaning
-         defaults = dict(history_size=history_size, update_freq=update_freq, damping=damping, order=order, A_beta=A_beta, B_beta=B_beta, solver=solver)
-         super().__init__(defaults, uses_grad=False, concat_params=concat_params, scale_first=scale_first, inner=inner, update_freq=interval)
-
-     @torch.no_grad
-     def update_tensor(self, tensor, param, grad, state, settings):
-         order = settings['order']
-         history_size = settings['history_size']
-         update_freq = settings['update_freq']
-         damping = settings['damping']
-         A_beta = settings['A_beta']
-         B_beta = settings['B_beta']
-         solver: _Solver = settings['solver']
-
-         if 'history' not in state: state['history'] = deque(maxlen=history_size)
-         history = state['history']
-
-         if order == 1: history.append(tensor.clone().view(-1))
-         else:
-
-             # if order=2, history is of gradient differences, order 3 is differences between differences, etc
-             # normalized by parameter differences
-             cur_p = param.clone()
-             cur_g = tensor.clone()
-             for i in range(1, order):
-                 if f'prev_g_{i}' not in state:
-                     state[f'prev_p_{i}'] = cur_p
-                     state[f'prev_g_{i}'] = cur_g
-                     break
-
-                 s_k = cur_p - state[f'prev_p_{i}']
-                 y_k = cur_g - state[f'prev_g_{i}']
-                 state[f'prev_p_{i}'] = cur_p
-                 state[f'prev_g_{i}'] = cur_g
-                 cur_p = s_k
-                 cur_g = y_k
-
-                 if i == order - 1:
-                     cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
-                     history.append(cur_g.view(-1))
-
-         step = state.get('step', 0)
-         if step % update_freq == 0 and len(history) != 0:
-             A, B = solver.update(history, damping=damping)
-             maybe_lerp_(state, A_beta, 'A', A)
-             maybe_lerp_(state, B_beta, 'B', B)
-
-         if len(history) != 0:
-             state['step'] = step + 1 # do not increment if no history (gathering s_ks and y_ks)
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, state, settings):
-         history_size = settings['history_size']
-         solver: _Solver = settings['solver']
-
-         A = state.get('A', None)
-         if A is None:
-             # make a conservative step to avoid issues due to different GD scaling
-             return tensor.clip_(-0.1, 0.1) # pyright:ignore[reportArgumentType]
-
-         B = state['B']
-         update = solver.apply(tensor.view(-1), A, B).view_as(tensor)
-
-         n = len(state['history'])
-         if n != history_size: update.mul_(n/history_size)
-         return update
-
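For orientation, the removed `SpectralPreconditioner` whitens each update by stacking the k most recent gradients into a d×k matrix, factorizing it, and applying U diag(1/(S + damping)) Uᵀ to the incoming gradient. Below is a minimal sketch of that core step in plain PyTorch; it is not the torchzero API, and the function and variable names are illustrative only.

```python
# Sketch of the whitening step used by the removed _SVDSolver above
# (assumes plain PyTorch; not part of torchzero).
import torch

def whiten(history: list[torch.Tensor], g: torch.Tensor, damping: float = 1e-12) -> torch.Tensor:
    M = torch.stack(history, dim=1)                      # (d, k) matrix of recent gradients
    U, S, _ = torch.linalg.svd(M, full_matrices=False)   # U: (d, k), S: (k,)
    S = S + damping                                      # damp small singular values
    return U @ ((U.T @ g) / S)                           # project, rescale, project back

d, k = 100, 10
hist = [torch.randn(d) for _ in range(k)]
g = torch.randn(d)
update = whiten(hist, g)
```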
@@ -1,111 +0,0 @@
- # idea https://arxiv.org/pdf/2212.09841
- import warnings
- from collections.abc import Callable
- from functools import partial
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, apply
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     hessian_mat,
-     hvp,
-     hvp_fd_central,
-     hvp_fd_forward,
-     jacobian_and_hessian_wrt,
- )
-
-
- class StructuredNewton(Module):
-     """TODO
-     Args:
-         structure (str, optional): structure.
-         reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
-         hvp_method (str):
-             how to calculate hvp_method. Defaults to "autograd".
-         inner (Chainable | None, optional): inner modules. Defaults to None.
-
-     """
-     def __init__(
-         self,
-         structure: Literal[
-             "diagonal",
-             "diagonal1",
-             "diagonal_abs",
-             "tridiagonal",
-             "circulant",
-             "toeplitz",
-             "toeplitz_like",
-             "hankel",
-             "rank1",
-             "rank2", # any rank
-         ]
-         | str = "diagonal",
-         reg: float = 1e-6,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-         h: float = 1e-3,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(reg=reg, hvp_method=hvp_method, structure=structure, h=h)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         hvp_method = settings['hvp_method']
-         structure = settings['structure']
-         h = settings['h']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hvp_method == 'autograd':
-             grad = vars.get_grad(create_graph=True)
-             def Hvp_fn1(x):
-                 return hvp(params, grad, x, retain_graph=True)
-             Hvp_fn = Hvp_fn1
-
-         elif hvp_method == 'forward':
-             grad = vars.get_grad()
-             def Hvp_fn2(x):
-                 return hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1]
-             Hvp_fn = Hvp_fn2
-
-         elif hvp_method == 'central':
-             grad = vars.get_grad()
-             def Hvp_fn3(x):
-                 return hvp_fd_central(closure, params, x, h=h, normalize=True)[1]
-             Hvp_fn = Hvp_fn3
-
-         else: raise ValueError(hvp_method)
-
-         # -------------------------------- inner step -------------------------------- #
-         update = vars.get_update()
-         if 'inner' in self.children:
-             update = apply(self.children['inner'], update, params=params, grads=grad, vars=vars)
-
-         # hessian
-         if structure.startswith('diagonal'):
-             H = Hvp_fn([torch.ones_like(p) for p in params])
-             if structure == 'diagonal1': torch._foreach_clamp_min_(H, 1)
-             if structure == 'diagonal_abs': torch._foreach_abs_(H)
-             torch._foreach_add_(H, reg)
-             torch._foreach_div_(update, H)
-             vars.update = update
-             return vars
-
-         # hessian
-         raise NotImplementedError(structure)
-
-
-
-
-
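Of the structures listed above, only the diagonal family is implemented: diag(H) is estimated with a single Hessian-vector product against the all-ones vector, and the update is divided by it elementwise. A standalone sketch of that branch follows, assuming plain `torch.autograd` rather than the torchzero `Module`/`vars` machinery; all names here are illustrative.

```python
# Sketch of the diagonal-structure branch of the removed StructuredNewton.
# Note that H @ 1 gives the row sums of H, so this only matches diag(H)
# when H is (close to) diagonal, which is the approximation the module makes.
import torch

def diagonal_newton_step(loss_fn, x: torch.Tensor, reg: float = 1e-6) -> torch.Tensor:
    x = x.detach().requires_grad_(True)
    loss = loss_fn(x)
    (g,) = torch.autograd.grad(loss, x, create_graph=True)
    (Hv,) = torch.autograd.grad(g, x, grad_outputs=torch.ones_like(x))  # H @ 1
    return g.detach() / (Hv.detach() + reg)                             # elementwise Newton step

x = torch.tensor([1.5, -2.0, 0.5])
step = diagonal_newton_step(lambda t: (t ** 4).sum(), x)
```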
@@ -1,136 +0,0 @@
- import warnings
- from functools import partial
- from typing import Literal
- from collections.abc import Callable
- import torch
-
- from ...core import Chainable, apply, Module
- from ...utils import vec_to_tensors, TensorList
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     hessian_mat,
-     jacobian_and_hessian_wrt,
- )
- from ..second_order.newton import lu_solve, cholesky_solve, least_squares_solve
-
- def tropical_sum(x, dim): return torch.amax(x, dim=dim)
- def tropical_mul(x, y): return x+y
-
- def tropical_matmul(x: torch.Tensor, y: torch.Tensor):
-     # this implements matmul by calling mul and sum
-
-     x_squeeze = False
-     y_squeeze = False
-
-     if x.ndim == 1:
-         x_squeeze = True
-         x = x.unsqueeze(0)
-
-     if y.ndim == 1:
-         y_squeeze = True
-         y = y.unsqueeze(1)
-
-     res = tropical_sum(tropical_mul(x.unsqueeze(-1), y.unsqueeze(-3)), dim = -2)
-
-     if x_squeeze: res = res.squeeze(-2)
-     if y_squeeze: res = res.squeeze(-1)
-
-     return res
-
- def tropical_dot(x:torch.Tensor, y:torch.Tensor):
-     assert x.ndim == 1 and y.ndim == 1
-     return tropical_matmul(x.unsqueeze(0), y.unsqueeze(1))
-
- def tropical_outer(x:torch.Tensor, y:torch.Tensor):
-     assert x.ndim == 1 and y.ndim == 1
-     return tropical_matmul(x.unsqueeze(1), y.unsqueeze(0))
-
-
- def tropical_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
-     r = b.unsqueeze(1) - A
-     return r.amin(dim=-2)
-
- def tropical_solve_and_reconstruct(A: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-     r = b.unsqueeze(1) - A
-     x = r.amin(dim=-2)
-     b_hat = tropical_matmul(A, x)
-     return x, b_hat
-
- def tikhonov(H: torch.Tensor, reg: float):
-     if reg!=0: H += torch.eye(H.size(-1), dtype=H.dtype, device=H.device) * reg
-     return H
-
-
- class TropicalNewton(Module):
-     """suston"""
-     def __init__(
-         self,
-         reg: float | None = None,
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-         interpolate:bool=False,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(reg=reg, hessian_method=hessian_method, vectorize=vectorize, interpolate=interpolate)
-         super().__init__(defaults)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         reg = settings['reg']
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-         interpolate = settings['interpolate']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hessian_method == 'autograd':
-             with torch.enable_grad():
-                 loss = vars.loss = vars.loss_approx = closure(False)
-                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-                 vars.grad = g_list
-                 H = hessian_list_to_mat(H_list)
-
-         elif hessian_method in ('func', 'autograd.functional'):
-             strat = 'forward-mode' if vectorize else 'reverse-mode'
-             with torch.enable_grad():
-                 g_list = vars.get_grad(retain_graph=True)
-                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                     method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-         else:
-             raise ValueError(hessian_method)
-
-         # -------------------------------- inner step -------------------------------- #
-         if 'inner' in self.children:
-             g_list = apply(self.children['inner'], list(g_list), params=params, grads=list(g_list), vars=vars)
-         g = torch.cat([t.view(-1) for t in g_list])
-
-         # ------------------------------ regularization ------------------------------ #
-         if reg is not None: H = tikhonov(H, reg)
-
-         # ----------------------------------- solve ---------------------------------- #
-         tropical_update, g_hat = tropical_solve_and_reconstruct(H, g)
-
-         g_norm = torch.linalg.vector_norm(g) # pylint:disable=not-callable
-         abs_error = torch.linalg.vector_norm(g-g_hat) # pylint:disable=not-callable
-         rel_error = abs_error/g_norm.clip(min=1e-8)
-
-         if interpolate:
-             if rel_error > 1e-8:
-
-                 update = cholesky_solve(H, g)
-                 if update is None: update = lu_solve(H, g)
-                 if update is None: update = least_squares_solve(H, g)
-
-                 tropical_update.lerp_(update.ravel(), rel_error.clip(max=1))
-
-         vars.update = vec_to_tensors(tropical_update, params)
-         return vars
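The removed module works in the max-plus (tropical) semiring, where "addition" is max and "multiplication" is +, so `tropical_matmul` computes (A ⊗ x)_i = max_j (A[i, j] + x[j]). A tiny worked example of that product, for illustration only:

```python
# Max-plus product as computed by tropical_matmul above (vector case).
import torch

A = torch.tensor([[0., 2.],
                  [1., 3.]])
x = torch.tensor([5., 4.])

# row 0: max(0 + 5, 2 + 4) = 6, row 1: max(1 + 5, 3 + 4) = 7
result = (A + x.unsqueeze(0)).amax(dim=1)   # tensor([6., 7.])
```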
@@ -1,2 +0,0 @@
- from .lr import LR, StepSize, Warmup
- from .step_size import PolyakStepSize, RandomStepSize
@@ -1,59 +0,0 @@
- import torch
-
- from ...core import Transform
- from ...utils import NumberList, TensorList, generic_eq
-
-
- def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
-     """multiplies by lr if lr is not 1"""
-     if generic_eq(lr, 1): return tensors
-     if inplace: return tensors.mul_(lr)
-     return tensors * lr
-
- class LR(Transform):
-     def __init__(self, lr: float):
-         defaults=dict(lr=lr)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         return lazy_lr(TensorList(tensors), lr=self.get_settings('lr', params=params), inplace=True)
-
- class StepSize(Transform):
-     """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
-     def __init__(self, step_size: float, key = 'step_size'):
-         defaults={"key": key, key: step_size}
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         lrs = []
-         for p in params:
-             settings = self.settings[p]
-             lrs.append(settings[settings['key']])
-         return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)
-
-
- def warmup(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
-     """returns warm up lr scalar"""
-     if step > steps: return end_lr
-     return start_lr + (end_lr - start_lr) * (step / steps)
-
- class Warmup(Transform):
-     def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
-         defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         start_lr, end_lr = self.get_settings('start_lr', 'end_lr', params=params, cls = NumberList)
-         num_steps = self.settings[params[0]]['steps']
-         step = self.global_state.get('step', 0)
-
-         target = lazy_lr(
-             TensorList(tensors),
-             lr=warmup(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
-             inplace=True
-         )
-         self.global_state['step'] = step + 1
-         return target
@@ -1,97 +0,0 @@
- import random
- from typing import Any
-
- import torch
-
- from ...core import Transform
- from ...utils import TensorList, NumberList
-
-
- class PolyakStepSize(Transform):
-     """Polyak step-size.
-
-     Args:
-         max (float | None, optional): maximum possible step size. Defaults to None.
-         min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-         use_grad (bool, optional):
-             if True, uses dot product of update and gradient to compute the step size.
-             Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
-             Defaults to True.
-         parameterwise (bool, optional):
-             if True, calculate Polyak step-size for each parameter separately,
-             if False calculate one global step size for all parameters. Defaults to False.
-         alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
-     """
-     def __init__(self, max: float | None = None, min_obj_value: float = 0, use_grad=True, parameterwise=False, alpha: float = 1):
-
-         defaults = dict(alpha=alpha, max=max, min_obj_value=min_obj_value, use_grad=use_grad, parameterwise=parameterwise)
-         super().__init__(defaults, uses_grad=use_grad)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         loss = vars.get_loss(False)
-         assert grads is not None
-         tensors = TensorList(tensors)
-         grads = TensorList(grads)
-         alpha = self.get_settings('alpha', params=params, cls=NumberList)
-         settings = self.settings[params[0]]
-         parameterwise = settings['parameterwise']
-         use_grad = settings['use_grad']
-         max = settings['max']
-         min_obj_value = settings['min_obj_value']
-
-         if parameterwise:
-             if use_grad: denom = (tensors * grads).sum()
-             else: denom = tensors.pow(2).sum()
-             polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
-             polyak_step_size = polyak_step_size.where(denom != 0, 0)
-             if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)
-
-         else:
-             if use_grad: denom = tensors.dot(grads)
-             else: denom = tensors.dot(tensors)
-             if denom == 0: polyak_step_size = 0 # we converged
-             else: polyak_step_size = (loss - min_obj_value) / denom
-
-             if max is not None:
-                 if polyak_step_size > max: polyak_step_size = max
-
-         tensors.mul_(alpha * polyak_step_size)
-         return tensors
-
-
-
- class RandomStepSize(Transform):
-     """Uses random global step size from `low` to `high`.
-
-     Args:
-         low (float, optional): minimum learning rate. Defaults to 0.
-         high (float, optional): maximum learning rate. Defaults to 1.
-         parameterwise (bool, optional):
-             if True, generate random step size for each parameter separately,
-             if False generate one global random step size. Defaults to False.
-     """
-     def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
-         defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
-         super().__init__(defaults, uses_grad=False)
-
-     @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         settings = self.settings[params[0]]
-         parameterwise = settings['parameterwise']
-
-         seed = settings['seed']
-         if 'generator' not in self.global_state:
-             self.global_state['generator'] = random.Random(seed)
-         generator: random.Random = self.global_state['generator']
-
-         if parameterwise:
-             low, high = self.get_settings('low', 'high', params=params)
-             lr = [generator.uniform(l, h) for l, h in zip(low, high)]
-         else:
-             low = settings['low']
-             high = settings['high']
-             lr = generator.uniform(low, high)
-
-         torch._foreach_mul_(tensors, lr)
-         return tensors
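The global branch above (`use_grad=True`, `parameterwise=False`) is the classical Polyak step size: alpha = (f(x) - f*) / <d, g>, which reduces to (f(x) - f*) / ||g||^2 when the update d is the gradient itself. A minimal standalone sketch follows; it is not the torchzero `Transform` API, and the function name and arguments are illustrative.

```python
# Sketch of the global Polyak step size computed by the removed PolyakStepSize.
import torch

def polyak_step_size(loss: float, update: torch.Tensor, grad: torch.Tensor,
                     min_obj_value: float = 0.0, max_step: float | None = None) -> float:
    denom = torch.dot(update.ravel(), grad.ravel()).item()
    if denom == 0:
        return 0.0                               # converged, or update orthogonal to gradient
    alpha = (loss - min_obj_value) / denom
    if max_step is not None:
        alpha = min(alpha, max_step)
    return alpha

g = torch.tensor([3.0, -4.0])                    # ||g||^2 = 25
alpha = polyak_step_size(5.0, g, g)              # (5 - 0) / 25 = 0.2
```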