torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/experimental/eigendescent.py (deleted)
@@ -1,120 +0,0 @@
- from contextlib import nullcontext
- import warnings
- from collections.abc import Callable
- from functools import partial
- import itertools
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, apply_transform
- from ...utils import TensorList, vec_to_tensors
- from ...utils.derivatives import (
-     hessian_list_to_mat,
-     jacobian_wrt, jacobian_and_hessian_wrt, hessian_mat,
- )
-
- def _batched_dot(x, y):
-     return (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
-
- def _cosine_similarity(x, y):
-     denom = torch.linalg.vector_norm(x, dim=-1) * torch.linalg.vector_norm(y, dim=-1).clip(min=torch.finfo(x.dtype).eps) # pylint:disable=not-callable
-     return _batched_dot(x, y) / denom
-
- class EigenDescent(Module):
-     """
-     Uses eigenvectors corresponding to certain eigenvalues. For now they are just extracted from hessian.
-
-     .. warning::
-         Experimental.
-
-     Args:
-         mode (str, optional):
-             - largest - use largest eigenvalue unless all eigenvalues are negative, then smallest is used.
-             - smallest - use smallest eigenvalue unless all eigenvalues are positive, then largest is used.
-             - mean-sign - use mean of eigenvectors multiplied by 1 or -1 if they point in opposite direction from gradient.
-             - mean-dot - use mean of eigenvectors multiplied by dot product with gradient.
-             - mean-cosine - use mean of eigenvectors multiplied by cosine similarity with gradient.
-             - mm - for testing.
-
-             Defaults to 'mean-sign'.
-         hessian_method (str, optional): how to calculate hessian. Defaults to "autograd".
-         vectorize (bool, optional): how to calculate hessian. Defaults to True.
-
-     """
-     def __init__(
-         self,
-         mode: Literal['largest', 'smallest','magnitude', 'mean-sign', 'mean-dot', 'mean-cosine', 'mm'] = 'mean-sign',
-         hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
-         vectorize: bool = True,
-     ):
-         defaults = dict(hessian_method=hessian_method, vectorize=vectorize, mode=mode)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         params = TensorList(var.params)
-         closure = var.closure
-         if closure is None: raise RuntimeError('NewtonCG requires closure')
-
-         settings = self.settings[params[0]]
-         mode = settings['mode']
-         hessian_method = settings['hessian_method']
-         vectorize = settings['vectorize']
-
-         # ------------------------ calculate grad and hessian ------------------------ #
-         if hessian_method == 'autograd':
-             with torch.enable_grad():
-                 loss = var.loss = var.loss_approx = closure(False)
-                 g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
-                 g_list = [t[0] for t in g_list] # remove leading dim from loss
-                 var.grad = g_list
-                 H = hessian_list_to_mat(H_list)
-
-         elif hessian_method in ('func', 'autograd.functional'):
-             strat = 'forward-mode' if vectorize else 'reverse-mode'
-             with torch.enable_grad():
-                 g_list = var.get_grad(retain_graph=True)
-                 H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
-                                               method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]
-
-         else:
-             raise ValueError(hessian_method)
-
-
-         # ----------------------------------- solve ---------------------------------- #
-         g = torch.cat([t.ravel() for t in g_list])
-         L, Q = torch.linalg.eigh(H) # L is sorted # pylint:disable=not-callable
-         if mode == 'largest':
-             # smallest eigenvalue if all eigenvalues are negative else largest
-             if L[-1] <= 0: d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'smallest':
-             # smallest eigenvalue if negative eigenvalues exist else largest
-             if L[0] <= 0: d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'magnitude':
-             # largest by magnitude
-             if L[0].abs() > L[-1].abs(): d = Q[0]
-             else: d = Q[-1]
-
-         elif mode == 'mean-dot':
-             d = ((g.unsqueeze(0) @ Q).squeeze(0) * Q).mean(1)
-
-         elif mode == 'mean-sign':
-             d = ((g.unsqueeze(0) @ Q).squeeze(0).sign() * Q).mean(1)
-
-         elif mode == 'mean-cosine':
-             d = (Q * _cosine_similarity(Q, g)).mean(1)
-
-         elif mode == 'mm':
-             d = (g.unsqueeze(0) @ Q).squeeze(0) / g.numel()
-
-         else:
-             raise ValueError(mode)
-
-         var.update = vec_to_tensors(g.dot(d).sign() * d, params)
-         return var
-
torchzero/modules/experimental/etf.py (deleted)
@@ -1,195 +0,0 @@
- from typing import cast
- import warnings
-
- import torch
-
- from ...core import Module
- from ...utils import vec_to_tensors, vec_to_tensors_, as_tensorlist
-
-
- class ExponentialTrajectoryFit(Module):
-     """A method.
-
-     .. warning::
-         Experimental.
-     """
-     def __init__(self, step_size=1e-2, adaptive:bool=True):
-         defaults = dict(step_size = step_size,adaptive=adaptive)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         step_size = self.settings[var.params[0]]['step_size']
-         adaptive = self.settings[var.params[0]]['adaptive']
-
-
-         # 1. perform 3 GD steps to obtain 4 points
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         for i in range(3):
-             if i == 0:
-                 grad = var.get_grad()
-                 if adaptive:
-                     step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
-
-             else:
-                 with torch.enable_grad(): closure()
-                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-         assert len(points) == 4, len(points)
-         x0, x1, x2, x3 = points
-         dim = x0.numel()
-
-         # 2. fit a generalized exponential curve
-         d0 = (x1 - x0).unsqueeze(1) # column vectors
-         d1 = (x2 - x1).unsqueeze(1)
-         d2 = (x3 - x2).unsqueeze(1)
-
-         # cat
-         D1 = torch.cat([d0, d1], dim=1)
-         D2 = torch.cat([d1, d2], dim=1)
-
-         # if points are collinear this will happen on sphere and a quadratic "line search" will minimize it
-         if x0.numel() >= 2:
-             if torch.linalg.matrix_rank(D1) < 2: # pylint:disable=not-callable
-                 pass # need to put a quadratic fit there
-
-         M = D2 @ torch.linalg.pinv(D1) # pylint:disable=not-callable # this defines the curve
-
-         # now we can predict x*
-         I = torch.eye(dim, device=x0.device, dtype=x0.dtype)
-         B = I - M
-         z = x1 - M @ x0
-
-         x_star = torch.linalg.lstsq(B, z).solution # pylint:disable=not-callable
-
-         vec_to_tensors_(x0, var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
-
-
-
- class ExponentialTrajectoryFitV2(Module):
-     """Should be better than one above, except it isn't.
-
-     .. warning::
-         Experimental.
-
-     """
-     def __init__(self, step_size=1e-3, num_steps: int= 4, adaptive:bool=True):
-         defaults = dict(step_size = step_size, num_steps=num_steps, adaptive=adaptive)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         step_size = self.settings[var.params[0]]['step_size']
-         num_steps = self.settings[var.params[0]]['num_steps']
-         adaptive = self.settings[var.params[0]]['adaptive']
-
-         # 1. perform 3 GD steps to obtain 4 points (or more)
-         grad = var.get_grad()
-         if adaptive:
-             step_size /= as_tensorlist(grad).abs().global_mean().clip(min=1e-4)
-
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         point_grads = [torch.cat([g.view(-1) for g in grad])]
-
-         for i in range(num_steps):
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-             closure(backward=True)
-             grad = [cast(torch.Tensor, p.grad) for p in var.params]
-             point_grads.append(torch.cat([g.view(-1) for g in grad]))
-
-
-         X = torch.stack(points, 1) # dim, num_steps+1
-         G = torch.stack(point_grads, 1)
-         dim = points[0].numel()
-
-         X = torch.cat([X, torch.ones(1, num_steps+1, dtype=G.dtype, device=G.device)])
-
-         P = G @ torch.linalg.pinv(X) # pylint:disable=not-callable
-         A = P[:, :dim]
-         b = -P[:, dim]
-
-         # symmetrize
-         A = 0.5 * (A + A.T)
-
-         # predict x*
-         x_star = torch.linalg.lstsq(A, b).solution # pylint:disable=not-callable
-
-         vec_to_tensors_(points[0], var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
-
-
-
-
- def _fit_exponential(y0, y1, y2):
-     """x0, x1 and x2 are assumed to be 0, 1, 2"""
-     r = (y2 - y1) / (y1 - y0)
-     ones = r==1
-     r[ones] = 0
-     B = (y1 - y0) / (r - 1)
-     A = y0 - B
-
-     A[ones] = 0
-     B[ones] = 0
-     return A, B, r
-
- class PointwiseExponential(Module):
-     """A stupid method (for my youtube channel).
-
-     .. warning::
-         Experimental.
-     """
-     def __init__(self, step_size: float = 1e-3, reg: float = 1e-2, steps = 10000):
-         defaults = dict(reg=reg, steps=steps, step_size=step_size)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def step(self, var):
-         closure = var.closure
-         assert closure is not None
-         settings = self.settings[var.params[0]]
-         step_size = settings['step_size']
-         reg = settings['reg']
-         steps = settings['steps']
-
-         # 1. perform 2 GD steps to obtain 3 points
-         points = [torch.cat([p.view(-1) for p in var.params])]
-         for i in range(2):
-             if i == 0: grad = var.get_grad()
-             else:
-                 with torch.enable_grad(): closure()
-                 grad = [cast(torch.Tensor, p.grad) for p in var.params]
-
-             # GD step
-             torch._foreach_sub_(var.params, grad, alpha=step_size)
-
-             points.append(torch.cat([p.view(-1) for p in var.params]))
-
-         assert len(points) == 3, len(points)
-         y0, y1, y2 = points
-
-         A, B, r = _fit_exponential(y0, y1, y2)
-         r = r.clip(max = 1-reg)
-         x_star = A + B * r**steps
-
-         vec_to_tensors_(y0, var.params)
-         difference = torch._foreach_sub(var.params, vec_to_tensors(x_star, var.params))
-         var.update = list(difference)
-         return var
torchzero/modules/experimental/exp_adam.py (deleted)
@@ -1,113 +0,0 @@
- from operator import itemgetter
- from functools import partial
- import math
- import torch
-
- from ...core import Module, Target, Transform, apply_transform, Chainable
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
- from ..step_size.lr import lazy_lr
- from ..momentum.experimental import sqrt_nag_ema_sq_
- from ..momentum.momentum import nag_
-
-
- def exp_adam_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_exp_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_exp_: TensorList | None = None,
-
-     # inner args
-     inner: Module | None = None,
-     params: list[torch.Tensor] | None = None,
-     grads: list[torch.Tensor] | None = None,
- ):
-     """Returns new tensors."""
-     tensors_exp = tensors.abs().clip_(max=math.log(torch.finfo(tensors[0].dtype).max) / 2).exp_()
-     exp_avg_exp_.lerp_(tensors_exp, 1-beta2)
-
-     if max_exp_avg_exp_ is not None:
-         max_exp_avg_exp_.maximum_(exp_avg_exp_)
-         exp_avg_exp_ = max_exp_avg_exp_
-
-     if inner is not None:
-         assert params is not None
-         tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-     return (exp_avg_.lazy_mul(alpha) / exp_avg_exp_.log().add_(eps))
-
- class ExpAdam(Transform):
-     """Adam but uses abs exp and log instead of square and sqrt.
-     The gradient will be clipped to half the maximum value representable by its dtype (around 50 for float32)
-
-     Args:
-         beta1 (float, optional): momentum. Defaults to 0.9.
-         beta2 (float, optional): second momentum. Defaults to 0.999.
-         eps (float, optional): epsilon. Defaults to 1e-8.
-         alpha (float, optional): learning rate. Defaults to 1.
-         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
-         pow (float, optional): power used in second momentum power and root. Defaults to 2.
-         debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
-     """
-     def __init__(
-         self,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad: bool = False,
-         alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-         inner: Chainable | None = None
-     ):
-         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults, uses_grad=False)
-
-         if inner is not None: self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
-
-         if amsgrad:
-             exp_avg, exp_avg_exp, max_exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', 'max_exp_avg_exp', cls=TensorList)
-         else:
-             exp_avg, exp_avg_exp = unpack_states(states, tensors, 'exp_avg', 'exp_avg_exp', cls=TensorList)
-             max_exp_avg_exp = None
-
-
-         return exp_adam_(
-             tensors=TensorList(tensors),
-             exp_avg_=exp_avg,
-             exp_avg_exp_=exp_avg_exp,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_exp_=max_exp_avg_exp,
-
-             # inner args
-             inner=self.children.get("inner", None),
-             params=params,
-             grads=grads,
-
-         )
torchzero/modules/experimental/expanded_lbfgs.py (deleted)
@@ -1,141 +0,0 @@
- from collections import deque
- from operator import itemgetter
- import torch
-
- from ...core import Transform, Chainable, Module, Var, apply_transform
- from ...utils import TensorList, as_tensorlist, NumberList
- from ...modules.quasi_newton.lbfgs import _adaptive_damping, lbfgs, _lerp_params_update_
-
- class ExpandedLBFGS(Module):
-     """L-BFGS but uses differences between more pairs than just consequtive. Window size controls how far away the pairs are allowed to be.
-     """
-     def __init__(
-         self,
-         history_size=10,
-         window_size:int=3,
-         tol: float | None = 1e-10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         params_beta: float | None = None,
-         grads_beta: float | None = None,
-         update_freq = 1,
-         z_beta: float | None = None,
-         tol_reset: bool = False,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(history_size=history_size, window_size=window_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-         super().__init__(defaults)
-
-         self.global_state['s_history'] = deque(maxlen=history_size)
-         self.global_state['y_history'] = deque(maxlen=history_size)
-         self.global_state['sy_history'] = deque(maxlen=history_size)
-         self.global_state['p_history'] = deque(maxlen=window_size)
-         self.global_state['g_history'] = deque(maxlen=window_size)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     def reset(self):
-         self.state.clear()
-         self.global_state['step'] = 0
-         self.global_state['s_history'].clear()
-         self.global_state['y_history'].clear()
-         self.global_state['sy_history'].clear()
-         self.global_state['p_history'].clear()
-         self.global_state['g_history'].clear()
-
-     @torch.no_grad
-     def step(self, var):
-         params = as_tensorlist(var.params)
-         update = as_tensorlist(var.get_update())
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         # history of s and k
-         s_history: deque[TensorList] = self.global_state['s_history']
-         y_history: deque[TensorList] = self.global_state['y_history']
-         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-         p_history: deque[TensorList] = self.global_state['p_history']
-         g_history: deque[TensorList] = self.global_state['g_history']
-
-         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
-             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-         params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')
-
-         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-         prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
-
-         # 1st step - there are no previous params and grads, lbfgs will do normalized GD step
-         if step == 0:
-             s = None; y = None; ys = None
-         else:
-             s = l_params - prev_l_params
-             y = l_update - prev_l_grad
-             ys = s.dot(y)
-
-             if damping:
-                 s, y, ys = _adaptive_damping(s, y, ys, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-         prev_l_params.copy_(l_params)
-         prev_l_grad.copy_(l_update)
-
-         # update effective preconditioning state
-         if step % update_freq == 0:
-             if ys is not None and ys > 1e-10:
-                 assert s is not None and y is not None
-                 s_history.append(s)
-                 y_history.append(y)
-                 sy_history.append(ys)
-
-             if len(p_history) > 1:
-                 for p_i, g_i in zip(list(p_history)[:-1], list(g_history)[:-1]):
-                     s_i = l_params - p_i
-                     y_i = l_update - g_i
-                     ys_i = s_i.dot(y_i)
-
-                     if ys_i > 1e-10:
-                         if damping:
-                             s_i, y_i, ys_i = _adaptive_damping(s_i, y_i, ys_i, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-                         s_history.append(s_i)
-                         y_history.append(y_i)
-                         sy_history.append(ys_i)
-
-             p_history.append(l_params.clone())
-             g_history.append(l_update.clone())
-
-
-         # step with inner module before applying preconditioner
-         if self.children:
-             update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-         # tolerance on gradient difference to avoid exploding after converging
-         if tol is not None:
-             if y is not None and y.abs().global_max() <= tol:
-                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 if tol_reset: self.reset()
-                 return var
-
-         # lerp initial H^-1 @ q guess
-         z_ema = None
-         if z_beta is not None:
-             z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)
-
-         # precondition
-         dir = lbfgs(
-             tensors_=as_tensorlist(update),
-             s_history=s_history,
-             y_history=y_history,
-             sy_history=sy_history,
-             y=y,
-             sy=ys,
-             z_beta = z_beta,
-             z_ema = z_ema,
-             step=step
-         )
-
-         var.update = dir
-
-         return var
-
torchzero/modules/experimental/hnewton.py (deleted)
@@ -1,85 +0,0 @@
- from collections import deque
-
- import torch
-
- from ...core import TensorwiseTransform
-
-
- def eigh_solve(H: torch.Tensor, g: torch.Tensor):
-     try:
-         L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
-         return Q @ ((Q.mH @ g) / L)
-     except torch.linalg.LinAlgError:
-         return None
-
-
- class HNewton(TensorwiseTransform):
-     """This treats gradient differences as Hvps with vectors being parameter differences, using past gradients that are close to each other. Basically this is another limited memory quasi newton method to test.
-
-     .. warning::
-         Experimental.
-
-     """
-     def __init__(self, history_size: int, window_size: int, reg: float=0, tol: float = 1e-8, concat_params:bool=True, inner=None):
-         defaults = dict(history_size=history_size, window_size=window_size, reg=reg, tol=tol)
-         super().__init__(defaults, uses_grad=False, concat_params=concat_params, inner=inner)
-
-     def update_tensor(self, tensor, param, grad, loss, state, setting):
-
-         history_size = setting['history_size']
-
-         if 'param_history' not in state:
-             state['param_history'] = deque(maxlen=history_size)
-             state['grad_history'] = deque(maxlen=history_size)
-
-         param_history: deque = state['param_history']
-         grad_history: deque = state['grad_history']
-         param_history.append(param.ravel())
-         grad_history.append(tensor.ravel())
-
-     def apply_tensor(self, tensor, param, grad, loss, state, setting):
-         window_size = setting['window_size']
-         reg = setting['reg']
-         tol = setting['tol']
-
-         param_history: deque = state['param_history']
-         grad_history: deque = state['grad_history']
-         g = tensor.ravel()
-
-         n = len(param_history)
-         s_list = []
-         y_list = []
-
-         for i in range(n):
-             for j in range(i):
-                 if i - j <= window_size:
-                     p_i, g_i = param_history[i], grad_history[i]
-                     p_j, g_j = param_history[j], grad_history[j]
-                     s = p_i - p_j # vec in hvp
-                     y = g_i - g_j # hvp
-                     if s.dot(y) > tol:
-                         s_list.append(s)
-                         y_list.append(y)
-
-         if len(s_list) < 1:
-             scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
-             tensor.mul_(scale)
-             return tensor
-
-         S = torch.stack(s_list, 1)
-         Y = torch.stack(y_list, 1)
-
-         B = S.T @ Y
-         if reg != 0: B.add_(torch.eye(B.size(0), device=B.device, dtype=B.dtype).mul_(reg))
-         g_proj = g @ S
-
-         newton_proj, info = torch.linalg.solve_ex(B, g_proj) # pylint:disable=not-callable
-         if info != 0:
-             newton_proj = -torch.linalg.lstsq(B, g_proj).solution # pylint:disable=not-callable
-         newton = S @ newton_proj
-         return newton.view_as(tensor)
-
-
-         # scale = (1 / tensor.abs().sum()).clip(min=torch.finfo(tensor.dtype).eps, max=1)
-         # tensor.mul_(scale)
-         # return tensor