torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/experimental/curveball.py

@@ -2,7 +2,7 @@ from typing import Literal
  from collections.abc import Callable
  import torch

- from ...core import Module, Target, Transform, Chainable, apply
+ from ...core import Module, Target, Transform, Chainable, apply_transform
  from ...utils import NumberList, TensorList, as_tensorlist
  from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central

@@ -47,27 +47,27 @@ class CurveBall(Module):
          if inner is not None: self.set_child('inner', inner)

      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):

-         params = vars.params
+         params = var.params
          settings = self.settings[params[0]]
          hvp_method = settings['hvp_method']
          h = settings['h']

-         precond_lr, momentum, reg = self.get_settings('momentum', 'decay_rate', 'reg', params=params, cls=NumberList)
+         precond_lr, momentum, reg = self.get_settings(params, 'precond_lr', 'momentum', 'reg', cls=NumberList)


-         closure = vars.closure
+         closure = var.closure
          assert closure is not None

-         z, Hz = self.get_state('z', 'Hz', params=params, cls=TensorList)
+         z, Hz = self.get_state(params, 'z', 'Hz', cls=TensorList)

          if hvp_method == 'autograd':
-             grad = vars.get_grad(create_graph=True)
+             grad = var.get_grad(create_graph=True)
              Hvp = hvp(params, grad, z)

          elif hvp_method == 'forward':
-             loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=vars.get_grad(), normalize=True)
+             loss, Hvp = hvp_fd_forward(closure, params, z, h=h, g_0=var.get_grad(), normalize=True)

          elif hvp_method == 'central':
              loss, Hvp = hvp_fd_central(closure, params, z, h=h, normalize=True)
@@ -79,11 +79,11 @@ class CurveBall(Module):
          Hz.set_(Hvp + z*reg)


-         update = vars.get_update()
+         update = var.get_update()
          if 'inner' in self.children:
-             update = apply(self.children['inner'], update, params, grads=vars.grad, vars=vars)
+             update = apply_transform(self.children['inner'], update, params, grads=var.grad, var=var)

          z = curveball(TensorList(update), z, Hz, momentum=momentum, precond_lr=precond_lr)
-         vars.update = z.neg()
+         var.update = z.neg()

-         return vars
+         return var
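
For orientation (an editor's gloss, not part of the diff): the call pattern above follows the CurveBall scheme of Henriques et al., which carries a Newton-like buffer z between steps. Assuming the curveball helper, which this hunk does not show, implements the published update, the recursion is roughly

    z_{t+1} = \rho\, z_t \;-\; \beta \bigl( (H + \lambda I)\, z_t + g_t \bigr)

with rho = momentum, beta = precond_lr, lambda = reg, and (H + lambda I) z_t the regularized Hessian-vector product stored in Hz; the module then hands z.neg() onward as var.update.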
torchzero/modules/experimental/diagonal_higher_order_newton.py

@@ -0,0 +1,225 @@
+ import itertools
+ import math
+ import warnings
+ from collections.abc import Callable
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import Any, Literal
+
+ import numpy as np
+ import scipy.optimize
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     jacobian_wrt,
+     hvp,
+ )
+
+ def _poly_eval_diag(s: np.ndarray, c, derivatives):
+     val = float(c) + (derivatives[0] * s).sum(-1)
+
+     if len(derivatives) > 1:
+         for i, d_diag in enumerate(derivatives[1:], 2):
+             val += (d_diag * (s**i)).sum(-1) / math.factorial(i)
+
+     return val
+
+ def _proximal_poly_v_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     """Computes the value of the proximal polynomial approximation."""
+     if x.ndim == 2: x = x.T
+     s = x - x0
+
+     val = _poly_eval_diag(s, c, derivatives)
+
+     penalty = 0
+     if prox != 0:
+         penalty = (prox / 2) * (s**2).sum(-1)
+
+     return val + penalty
+
+ def _proximal_poly_g_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     """Computes the gradient of the proximal polynomial approximation."""
+     s = x - x0
+
+     g = derivatives[0].copy()
+
+     if len(derivatives) > 1:
+         for i, d_diag in enumerate(derivatives[1:], 2):
+             g += d_diag * (s**(i - 1)) / math.factorial(i - 1)
+
+     if prox != 0:
+         g += prox * s
+
+     return g
+
+ def _proximal_poly_H_diag(x: np.ndarray, c, prox, x0: np.ndarray, derivatives):
+     """Computes the Hessian of the proximal polynomial approximation."""
+     s = x - x0
+     n = x.shape[0]
+
+     if len(derivatives) < 2:
+         H_diag = np.zeros(n, dtype=s.dtype)
+     else:
+         H_diag = derivatives[1].copy()
+
+     if len(derivatives) > 2:
+         for i, d_diag in enumerate(derivatives[2:], 3):
+             H_diag += d_diag * (s**(i - 2)) / math.factorial(i - 2)
+
+     if prox != 0:
+         H_diag += prox
+
+     return np.diag(H_diag)
+
+ def _poly_minimize(trust_region, prox, de_iters: Any, c, x: torch.Tensor, derivatives):
+     derivatives = [T.detach().cpu().numpy().astype(np.float64) for T in derivatives]
+     x0 = x.detach().cpu().numpy().astype(np.float64) # taylor series center
+     bounds = None
+     if trust_region is not None: bounds = list(zip(x0 - trust_region, x0 + trust_region))
+
+     # if len(derivatives) is 1, only gradient is available, I use that to test proximal penalty and bounds
+     if bounds is None:
+         if len(derivatives) == 1: method = 'bfgs'
+         else: method = 'trust-exact'
+     else:
+         if len(derivatives) == 1: method = 'l-bfgs-b'
+         else: method = 'trust-constr'
+
+     x_init = x0.copy()
+     v0 = _proximal_poly_v_diag(x0, c, prox, x0, derivatives)
+     if de_iters is not None and de_iters != 0:
+         if de_iters == -1: de_iters = None # let scipy decide
+         res = scipy.optimize.differential_evolution(
+             _proximal_poly_v_diag,
+             bounds if bounds is not None else list(zip(x0 - 10, x0 + 10)),
+             args=(c, prox, x0.copy(), derivatives),
+             maxiter=de_iters,
+             vectorized=True,
+         )
+         if res.fun < v0: x_init = res.x
+
+     res = scipy.optimize.minimize(
+         _proximal_poly_v_diag,
+         x_init,
+         method=method,
+         args=(c, prox, x0.copy(), derivatives),
+         jac=_proximal_poly_g_diag,
+         hess=_proximal_poly_H_diag,
+         bounds=bounds
+     )
+
+     return torch.from_numpy(res.x).to(x), res.fun
+
+
+
+ class DiagonalHigherOrderNewton(Module):
+     """
+     Hvp with ones doesn't give you the diagonal unless derivatives are diagonal, but somehow it still works,
+     except it doesn't work in all cases except ones where it works.
+     """
+     def __init__(
+         self,
+         order: int = 4,
+         trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
+         increase: float = 1.5,
+         decrease: float = 0.75,
+         trust_init: float | None = None,
+         trust_tol: float = 1,
+         de_iters: int | None = None,
+         vectorize: bool = True,
+     ):
+         if trust_init is None:
+             if trust_method == 'bounds': trust_init = 1
+             else: trust_init = 0.1
+
+         defaults = dict(order=order, trust_method=trust_method, increase=increase, decrease=decrease, trust_tol=trust_tol, trust_init=trust_init, vectorize=vectorize, de_iters=de_iters)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         order = settings['order']
+         increase = settings['increase']
+         decrease = settings['decrease']
+         trust_tol = settings['trust_tol']
+         trust_init = settings['trust_init']
+         trust_method = settings['trust_method']
+         de_iters = settings['de_iters']
+
+         trust_value = self.global_state.get('trust_value', trust_init)
+
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         with torch.enable_grad():
+             loss = var.loss = var.loss_approx = closure(False)
+
+             g = torch.autograd.grad(loss, params, create_graph=True)
+             var.grad = list(g)
+
+             derivatives = [g]
+             T = g # current derivatives tensor diagonal
+             ones = [torch.ones_like(t) for t in g]
+
+             # get all derivatives up to order
+             for o in range(2, order + 1):
+                 T = hvp(params, T, ones, create_graph=o != order)
+                 derivatives.append(T)
+
+         x0 = torch.cat([p.ravel() for p in params])
+
+         if trust_method is None: trust_method = 'none'
+         else: trust_method = trust_method.lower()
+
+         if trust_method == 'none':
+             trust_region = None
+             prox = 0
+
+         elif trust_method == 'bounds':
+             trust_region = trust_value
+             prox = 0
+
+         elif trust_method == 'proximal':
+             trust_region = None
+             prox = 1 / trust_value
+
+         else:
+             raise ValueError(trust_method)
+
+         x_star, expected_loss = _poly_minimize(
+             trust_region=trust_region,
+             prox=prox,
+             de_iters=de_iters,
+             c=loss.item(),
+             x=x0,
+             derivatives=[torch.cat([t.ravel() for t in d]) for d in derivatives],
+         )
+
+         # trust region
+         if trust_method != 'none':
+             expected_reduction = loss - expected_loss
+
+             vec_to_tensors_(x_star, params)
+             loss_star = closure(False)
+             vec_to_tensors_(x0, params)
+             reduction = loss - loss_star
+
+             # failed step
+             if reduction <= 0:
+                 x_star = x0
+                 self.global_state['trust_value'] = trust_value * decrease
+
+             # very good step
+             elif expected_reduction / reduction <= trust_tol:
+                 self.global_state['trust_value'] = trust_value * increase
+
+         difference = vec_to_tensors(x0 - x_star, params)
+         var.update = list(difference)
+         return var
+
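
For orientation (an editor's gloss, not part of the diff): the helpers above define a separable, diagonal Taylor model of the loss around the current point, which _poly_minimize hands to scipy.optimize. Writing s = x - x_0 and d^(k) for the k-th derivative diagonal passed in derivatives, the objective is

    \hat f(x_0 + s) \;=\; c \;+\; \sum_{k=1}^{\text{order}} \frac{1}{k!} \sum_i d^{(k)}_i\, s_i^{\,k} \;+\; \frac{\mu}{2}\,\lVert s \rVert_2^2

where c is the current loss and mu is prox (nonzero only for trust_method='proximal'); with trust_method='bounds' the penalty is dropped and the minimization is instead box-constrained to |s_i| <= trust_value. DiagonalHigherOrderNewton.step then accepts or rejects x_star and rescales trust_value by increase or decrease, in the style of a trust-region update.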
torchzero/modules/experimental/eigendescent.py

@@ -0,0 +1,117 @@
+ from contextlib import nullcontext
+ import warnings
+ from collections.abc import Callable
+ from functools import partial
+ import itertools
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, vec_to_tensors
+ from ...utils.derivatives import (
+     hessian_list_to_mat,
+     jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
+ )
+
+ def _batched_dot(x, y):
+     return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
+
+ def _cosine_similarity(x, y):
+     denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
+     return _batched_dot(x, y) / denom
+
+ class EigenDescent(Module):
+     """
+     Uses eigenvectors corresponding to certain eigenvalues. Please note that this is experimental and isn't guaranteed to work.
+
+     Args:
+         mode (str, optional):
+             - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
+             - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
+             - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
+             - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
+             - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
+             - mm - for testing.
+
+             Defaults to 'mean-sign'.
+         hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
+         vectorize (bool, optional): how to calculate hessian. Defaults to True.
+
+     """
+     def __init__(
+         self,
+         mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
+         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
+         vectorize: bool = True,
+     ):
+         defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
+         if closure is None: raise RuntimeError('NewtonCG requires closure')
+
+         settings = self.settings[params[0]]
+         mode = settings['mode']
+         hessian_method = settings['hessian_method']
+         vectorize = settings['vectorize']
+
+         # ------------------------ calculate grad and hessian ------------------------ #
+         if hessian_method == 'autograd':
+             with torch.enable_grad():
+                 loss = var.loss = var.loss_approx = closure(False)
+                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
+                 g_list = [t[0] for t in g_list] # remove leading dim from loss
+                 var.grad = g_list
+                 H = hessian_list_to_mat(H_list)
+
+         elif hessian_method in ('func', 'autograd.functional'):
+             strat = 'forward-mode' if vectorize else 'reverse-mode'
+             with torch.enable_grad():
+                 g_list = var.get_grad(retain_graph=True)
+                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
+                                               method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
+
+         else:
+             raise ValueError(hessian_method)
+
+
+         # ----------------------------------- solve ---------------------------------- #
+         g = torch.cat([t.ravel() for t in g_list])
+         L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
+         if mode == 'largest':
+             # smallest eigenvalue if all eigenvalues are negative else largest
+             if L[-1] <= 0: d = Q[0]
+             else: d = Q[-1]
+
+         elif mode == 'smallest':
+             # smallest eigenvalue if negative eigenvalues exist else largest
+             if L[0] <= 0: d = Q[0]
+             else: d = Q[-1]
+
+         elif mode == 'magnitude':
+             # largest by magnitude
+             if L[0].abs() > L[-1].abs(): d = Q[0]
+             else: d = Q[-1]
+
+         elif mode == 'mean-dot':
+             d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
+
+         elif mode == 'mean-sign':
+             d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
+
+         elif mode == 'mean-cosine':
+             d = (Q * _cosine_similarity(Q, g)).mean(1)
+
+         elif mode == 'mm':
+             d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
+
+         else:
+             raise ValueError(mode)
+
+         var.update = vec_to_tensors(g.dot(d).sign() * d, params)
+         return var
+
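
A short note on the mean-* modes above (read off the code, not separate documentation): torch.linalg.eigh returns eigenvectors as the columns q_j of Q, and the .mean(1) reductions average over those columns, so

    d_{\text{mean-sign}} = \frac{1}{n} \sum_{j=1}^{n} \operatorname{sign}\!\bigl(g^\top q_j\bigr)\, q_j, \qquad d_{\text{mean-dot}} = \frac{1}{n} \sum_{j=1}^{n} \bigl(g^\top q_j\bigr)\, q_j

and the final line flips the chosen direction so it has a non-negative inner product with the gradient (g.dot(d).sign() * d) before it becomes var.update.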
torchzero/modules/experimental/etf.py

@@ -0,0 +1,172 @@
+ from typing import cast
+ import warnings
+
+ import torch
+
+ from ...core import Module
+ from ...utils import vec_to_tensors, vec_to_tensors_
+
+
+ class ExponentialTrajectoryFit(Module):
+     """A method. Please note that this is experimental and isn't guaranteed to work."""
+     def __init__(self, step_size=1e-3):
+         defaults = dict(step_size = step_size)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         assert closure is not None
+         step_size = self.settings[var.params[0]]['step_size']
+
+         # 1. perform 3 GD steps to obtain 4 points
+         points = [torch.cat([p.view(-1) for p in var.params])]
+         for i in range(3):
+             if i == 0: grad = var.get_grad()
+             else:
+                 with torch.enable_grad(): closure()
+                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
+
+             # GD step
+             torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+             points.append(torch.cat([p.view(-1) for p in var.params]))
+
+         assert len(points) == 4, len(points)
+         x0, x1, x2, x3 = points
+         dim = x0.numel()
+
+         # 2. fit a generalized exponential curve
+         d0 = (x1 - x0).unsqueeze(1) # column vectors
+         d1 = (x2 - x1).unsqueeze(1)
+         d2 = (x3 - x2).unsqueeze(1)
+
+         # cat
+         D1 = torch.cat([d0, d1], dim=1)
+         D2 = torch.cat([d1, d2], dim=1)
+
+         # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
+         if x0.numel() >= 2:
+             if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
+                 pass # need to put a quadratic fit there
+
+         M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
+
+         # now we can predict x*
+         I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
+         B = I - M
+         z = x1 - M @ x0
+
+         x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
+
+         vec_to_tensors_(x0, var.params)
+         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+         var.update = list(difference)
+         return var
+
+
+
+ class ExponentialTrajectoryFitV2(Module):
+     """Should be better than one above, except it isn't. Please note that this is experimental and isn't guaranteed to work."""
+     def __init__(self, step_size=1e-3, num_steps: int= 4):
+         defaults = dict(step_size = step_size, num_steps=num_steps)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         assert closure is not None
+         step_size = self.settings[var.params[0]]['step_size']
+         num_steps = self.settings[var.params[0]]['num_steps']
+
+         # 1. perform 3 GD steps to obtain 4 points (or more)
+         grad = var.get_grad()
+         points = [torch.cat([p.view(-1) for p in var.params])]
+         point_grads = [torch.cat([g.view(-1) for g in grad])]
+
+         for i in range(num_steps):
+             # GD step
+             torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+             points.append(torch.cat([p.view(-1) for p in var.params]))
+
+             closure(backward=True)
+             grad = [cast(torch.Tensor, p.grad) for p in var.params]
+             point_grads.append(torch.cat([g.view(-1) for g in grad]))
+
+
+         X = torch.stack(points, 1) # dim, num_steps+1
+         G = torch.stack(point_grads, 1)
+         dim = points[0].numel()
+
+         X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
+
+         P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
+         A = P[:, :dim]
+         b = -P[:, dim]
+
+         # symmetrize
+         A = 0.5 * (A + A.T)
+
+         # predict x*
+         x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
+
+         vec_to_tensors_(points[0], var.params)
+         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+         var.update = list(difference)
+         return var
+
+
+
+
+ def _fit_exponential(y0, y1, y2):
+     """x0, x1 and x2 are assumed to be 0, 1, 2"""
+     r = (y2 - y1) / (y1 - y0)
+     ones = r==1
+     r[ones] = 0
+     B = (y1 - y0) / (r - 1)
+     A = y0 - B
+
+     A[ones] = 0
+     B[ones] = 0
+     return A, B, r
+
+ class PointwiseExponential(Module):
+     """A stupid method (for my youtube channel). Please note that this is experimental and isn't guaranteed to work."""
+     def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
+         defaults = dict(reg=reg, steps=steps, step_size=step_size)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         assert closure is not None
+         settings = self.settings[var.params[0]]
+         step_size = settings['step_size']
+         reg = settings['reg']
+         steps = settings['steps']
+
+         # 1. perform 2 GD steps to obtain 3 points
+         points = [torch.cat([p.view(-1) for p in var.params])]
+         for i in range(2):
+             if i == 0: grad = var.get_grad()
+             else:
+                 with torch.enable_grad(): closure()
+                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
+
+             # GD step
+             torch._foreach_sub_(var.params, grad, alpha=step_size)
+
+             points.append(torch.cat([p.view(-1) for p in var.params]))
+
+         assert len(points) == 3, len(points)
+         y0, y1, y2 = points
+
+         A, B, r = _fit_exponential(y0, y1, y2)
+         r = r.clip(max = 1-reg)
+         x_star = A + B * r**steps
+
+         vec_to_tensors_(y0, var.params)
+         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
+         var.update = list(difference)
+         return var
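
For context on the linear-algebra step in ExponentialTrajectoryFit (again read off the code above, not separate documentation): consecutive iterate differences are modeled with a fixed linear map, d_{k+1} ≈ M d_k with M = D_2 D_1^+, which is the affine recurrence x_{k+1} = M x_k + c with c = x_1 - M x_0. Its fixed point satisfies

    (I - M)\, x_\star \;=\; x_1 - M x_0

and x_star is the least-squares solution of that system. _fit_exponential is the per-coordinate scalar analogue: it fits y(t) = A + B r^t through three points via r = (y_2 - y_1)/(y_1 - y_0), B = (y_1 - y_0)/(r - 1), A = y_0 - B, and PointwiseExponential extrapolates x_star = A + B r^{steps} with r clipped to at most 1 - reg.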
torchzero/modules/experimental/gradmin.py

@@ -5,7 +5,7 @@ from typing import Literal

  import torch

- from ...core import Module, Vars
+ from ...core import Module, Var
  from ...utils import NumberList, TensorList
  from ...utils.derivatives import jacobian_wrt
  from ..grad_approximation import GradApproximator, GradTarget
@@ -42,7 +42,7 @@ class GradMin(Reformulation):
          super().__init__(defaults)

      @torch.no_grad
-     def closure(self, backward, closure, params, vars):
+     def closure(self, backward, closure, params, var):
          settings = self.settings[params[0]]
          loss_term = settings['loss_term']
          relative = settings['relative']
torchzero/modules/experimental/newton_solver.py

@@ -3,13 +3,13 @@ from typing import Any, Literal, overload

  import torch

- from ...core import Chainable, Module, apply, Modular
+ from ...core import Chainable, Module, apply_transform, Modular
  from ...utils import TensorList, as_tensorlist
  from ...utils.derivatives import hvp
  from ..quasi_newton import LBFGS

  class NewtonSolver(Module):
-     """Matrix free newton via with any custom solver (usually it is better to just use NewtonCG or NystromPCG is even better)"""
+     """Matrix free newton via with any custom solver (this is for testing, use NewtonCG or NystromPCG)"""
      def __init__(
          self,
          solver: Callable[[list[torch.Tensor]], Any] = lambda p: Modular(p, LBFGS()),
@@ -26,9 +26,9 @@ class NewtonSolver(Module):
          self.set_child('inner', inner)

      @torch.no_grad
-     def step(self, vars):
-         params = TensorList(vars.params)
-         closure = vars.closure
+     def step(self, var):
+         params = TensorList(var.params)
+         closure = var.closure
          if closure is None: raise RuntimeError('NewtonCG requires closure')

          settings = self.settings[params[0]]
@@ -39,7 +39,7 @@ class NewtonSolver(Module):
          warm_start = settings['warm_start']

          # ---------------------- Hessian vector product function --------------------- #
-         grad = vars.get_grad(create_graph=True)
+         grad = var.get_grad(create_graph=True)

          def H_mm(x):
              with torch.enable_grad():
@@ -50,11 +50,11 @@ class NewtonSolver(Module):
          # -------------------------------- inner step -------------------------------- #
          b = as_tensorlist(grad)
          if 'inner' in self.children:
-             b = as_tensorlist(apply(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, vars=vars))
+             b = as_tensorlist(apply_transform(self.children['inner'], [g.clone() for g in grad], params=params, grads=grad, var=var))

          # ---------------------------------- run cg ---------------------------------- #
          x0 = None
-         if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+         if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
          if x0 is None: x = b.zeros_like().requires_grad_(True)
          else: x = x0.clone().requires_grad_(True)

@@ -76,13 +76,13 @@ class NewtonSolver(Module):
              assert loss is not None
              if min(loss, loss/initial_loss) < tol: break

-             print(f'{loss = }')
+             # print(f'{loss = }')

          if warm_start:
              assert x0 is not None
              x0.copy_(x)

-         vars.update = x.detach()
-         return vars
+         var.update = x.detach()
+         return var

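
Stepping back from the diff for a moment (an illustrative aside, not torchzero code): the idea NewtonSolver implements is to obtain a Newton direction d with H d = g without materializing H, using only Hessian-vector products and an inner solver applied to the quadratic q(d) = 1/2 d^T H d - g^T d. A minimal self-contained sketch of that idea in plain PyTorch, with all names (newton_direction, loss_fn, inner_steps) purely illustrative:

import torch

def newton_direction(loss_fn, params, inner_steps=100, lr=1e-1):
    # gradient with a graph attached so we can differentiate through it again
    with torch.enable_grad():
        loss = loss_fn()
        grad = torch.autograd.grad(loss, params, create_graph=True)

    d = [torch.zeros_like(p) for p in params]  # candidate Newton direction
    for _ in range(inner_steps):
        # Hessian-vector product H d via double backward through `grad`
        Hd = torch.autograd.grad(grad, params, grad_outputs=d, retain_graph=True)
        with torch.no_grad():
            # gradient of q(d) = 1/2 d^T H d - g^T d is (H d - g); descend on it
            for di, hdi, gi in zip(d, Hd, grad):
                di -= lr * (hdi - gi)
    return d

# usage sketch: d approximately solves H d = g, so the Newton step is p <- p - d
# params = [p for p in model.parameters() if p.requires_grad]
# d = newton_direction(lambda: criterion(model(x), y), params)
# with torch.no_grad():
#     for p, di in zip(params, d):
#         p -= di

Like any plain descent on that quadratic, the sketch only converges for positive-definite H and a small enough inner learning rate; NewtonSolver instead delegates the inner solve to a configurable torchzero solver (LBFGS by default, per the `solver` argument above) with optional warm starting between outer steps.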