torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/line_search/polynomial.py (deleted)
@@ -1,233 +0,0 @@
- import numpy as np
- import torch
-
- from .line_search import LineSearchBase
-
-
- # polynomial interpolation
- # this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
- # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
- def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
-     """
-     Gives the minimizer and minimum of the interpolating polynomial over given points
-     based on function and derivative information. Defaults to bisection if no critical
-     points are valid.
-
-     Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
-     modifications.
-
-     Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
-     Last edited 12/6/18.
-
-     Inputs:
-         points (nparray): two-dimensional array with each point of form [x f g]
-         x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
-         x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
-         plot (bool): plot interpolating polynomial
-
-     Outputs:
-         x_sol (float): minimizer of interpolating polynomial
-         F_min (float): minimum of interpolating polynomial
-
-     Note:
-         . Set f or g to np.nan if they are unknown
-
-     """
-     no_points = points.shape[0]
-     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
-
-     x_min = np.min(points[:, 0])
-     x_max = np.max(points[:, 0])
-
-     # compute bounds of interpolation area
-     if x_min_bound is None:
-         x_min_bound = x_min
-     if x_max_bound is None:
-         x_max_bound = x_max
-
-     # explicit formula for quadratic interpolation
-     if no_points == 2 and order == 2 and plot is False:
-         # Solution to quadratic interpolation is given by:
-         # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
-         # x_min = x1 - g1/(2a)
-         # if x1 = 0, then is given by:
-         # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
-
-         if points[0, 0] == 0:
-             x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
-         else:
-             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
-             x_sol = points[0, 0] - points[0, 2]/(2*a)
-
-         x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
-
-     # explicit formula for cubic interpolation
-     elif no_points == 2 and order == 3 and plot is False:
-         # Solution to cubic interpolation is given by:
-         # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
-         # d2 = sqrt(d1^2 - g1*g2)
-         # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
-         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
-         d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
-         if np.isreal(d2):
-             x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
-             x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
-         else:
-             x_sol = (x_max_bound + x_min_bound)/2
-
-     # solve linear system
-     else:
-         # define linear constraints
-         A = np.zeros((0, order + 1))
-         b = np.zeros((0, 1))
-
-         # add linear constraints on function values
-         for i in range(no_points):
-             if not np.isnan(points[i, 1]):
-                 constraint = np.zeros((1, order + 1))
-                 for j in range(order, -1, -1):
-                     constraint[0, order - j] = points[i, 0] ** j
-                 A = np.append(A, constraint, 0)
-                 b = np.append(b, points[i, 1])
-
-         # add linear constraints on gradient values
-         for i in range(no_points):
-             if not np.isnan(points[i, 2]):
-                 constraint = np.zeros((1, order + 1))
-                 for j in range(order):
-                     constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
-                 A = np.append(A, constraint, 0)
-                 b = np.append(b, points[i, 2])
-
-         # check if system is solvable
-         if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
-             x_sol = (x_min_bound + x_max_bound)/2
-             f_min = np.inf
-         else:
-             # solve linear system for interpolating polynomial
-             coeff = np.linalg.solve(A, b)
-
-             # compute critical points
-             dcoeff = np.zeros(order)
-             for i in range(len(coeff) - 1):
-                 dcoeff[i] = coeff[i] * (order - i)
-
-             crit_pts = np.array([x_min_bound, x_max_bound])
-             crit_pts = np.append(crit_pts, points[:, 0])
-
-             if not np.isinf(dcoeff).any():
-                 roots = np.roots(dcoeff)
-                 crit_pts = np.append(crit_pts, roots)
-
-             # test critical points
-             f_min = np.inf
-             x_sol = (x_min_bound + x_max_bound) / 2 # defaults to bisection
-             for crit_pt in crit_pts:
-                 if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
-                     F_cp = np.polyval(coeff, crit_pt)
-                     if np.isreal(F_cp) and F_cp < f_min:
-                         x_sol = np.real(crit_pt)
-                         f_min = np.real(F_cp)
-
-             if(plot):
-                 import matplotlib.pyplot as plt
-                 plt.figure()
-                 x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
-                 f = np.polyval(coeff, x)
-                 plt.plot(x, f)
-                 plt.plot(x_sol, f_min, 'x')
-
-     return x_sol
-
-
-
- # class PolynomialLineSearch(LineSearch):
- #     """TODO
-
- #     Line search via polynomial interpolation.
-
- #     Args:
- #         init (float, optional): Initial step size. Defaults to 1.0.
- #         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
- #         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
- #         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
- #         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
- #         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
- #         adaptive (bool, optional):
- #             when enabled, if line search failed, initial step size is reduced.
- #             Otherwise it is reset to initial value. Defaults to True.
- #         plus_minus (bool, optional):
- #             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
-
-
- #     Examples:
- #         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
-
- #         .. code-block:: python
-
- #             opt = tz.Modular(
- #                 model.parameters(),
- #                 tz.m.PolakRibiere(),
- #                 tz.m.StrongWolfe(c2=0.1)
- #             )
-
- #         LBFGS strong wolfe line search:
-
- #         .. code-block:: python
-
- #             opt = tz.Modular(
- #                 model.parameters(),
- #                 tz.m.LBFGS(),
- #                 tz.m.StrongWolfe()
- #             )
-
- #     """
- #     def __init__(
- #         self,
- #         init: float = 1.0,
- #         c1: float = 1e-4,
- #         c2: float = 0.9,
- #         maxiter: int = 25,
- #         maxzoom: int = 10,
- #         # a_max: float = 1e10,
- #         expand: float = 2.0,
- #         adaptive = True,
- #         plus_minus = False,
- #     ):
- #         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
- #             expand=expand, adaptive=adaptive, plus_minus=plus_minus)
- #         super().__init__(defaults=defaults)
-
- #         self.global_state['initial_scale'] = 1.0
- #         self.global_state['beta_scale'] = 1.0
-
- #     @torch.no_grad
- #     def search(self, update, var):
- #         objective = self.make_objective_with_derivative(var=var)
-
- #         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
- #             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
- #             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
-
- #         f_0, g_0 = objective(0)
-
- #         step_size,f_a = strong_wolfe(
- #             objective,
- #             f_0=f_0, g_0=g_0,
- #             init=init * self.global_state.setdefault("initial_scale", 1),
- #             c1=c1,
- #             c2=c2,
- #             maxiter=maxiter,
- #             maxzoom=maxzoom,
- #             expand=expand,
- #             plus_minus=plus_minus,
- #         )
-
- #         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
- #         if step_size is not None and step_size != 0 and not _notfinite(step_size):
- #             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
- #             return step_size
-
- #         # fallback to backtracking on fail
- #         if adaptive: self.global_state['initial_scale'] *= 0.5
- #         return 0
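The quadratic branch of the removed `polyinterp` above uses a closed form: when one point supplies the value f1 and slope g1 at x1 = 0 and the other supplies only the value f2 at x2, the minimizer of the fitted parabola is x* = -(g1*x2^2) / (2*(f2 - f1 - g1*x2)). Below is a minimal standalone check of that formula (illustrative only, not torchzero code; the new `torchzero/modules/line_search/_polyinterp.py` listed above appears to take over this interpolation logic):

```python
# Illustrative sketch (not torchzero code): check the two-point quadratic formula above.
# f(x) = (x - 2)**2 + 1 is sampled at x1 = 0 (value and slope) and x2 = 1 (value only);
# the closed form should recover the true minimizer x = 2.
def quad_min(f1, g1, x2, f2):
    # special case x1 = 0 from the comment in polyinterp
    return -(g1 * x2 ** 2) / (2 * (f2 - f1 - g1 * x2))

f = lambda x: (x - 2) ** 2 + 1
g = lambda x: 2 * (x - 2)
print(quad_min(f1=f(0.0), g1=g(0.0), x2=1.0, f2=f(1.0)))  # prints 2.0
```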
torchzero/modules/momentum/matrix_momentum.py (deleted)
@@ -1,193 +0,0 @@
- from typing import Literal
-
- import torch
-
- from ...core import Module, apply_transform, Chainable
- from ...utils import NumberList, TensorList, as_tensorlist
- from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
- class MatrixMomentum(Module):
-     """Second order momentum method.
-
-     Matrix momentum is useful for convex objectives, also for some reason it has very really good generalization on elastic net logistic regression.
-
-     .. note::
-         :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
-
-     .. note::
-         I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
-         without having to tune :code:`mu`.
-
-     .. note::
-         In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
-
-     .. note::
-         This module requires the a closure passed to the optimizer step,
-         as it needs to re-evaluate the loss and gradients for calculating HVPs.
-         The closure must accept a ``backward`` argument (refer to documentation).
-
-     Args:
-         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
-         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-         hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-     Reference:
-         Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-     """
-
-     def __init__(
-         self,
-         mu=0.1,
-         beta: float = 1,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-         h: float = 1e-3,
-         hvp_tfm: Chainable | None = None,
-     ):
-         defaults = dict(mu=mu, beta=beta, hvp_method=hvp_method, h=h)
-         super().__init__(defaults)
-
-         if hvp_tfm is not None:
-             self.set_child('hvp_tfm', hvp_tfm)
-
-     def reset_for_online(self):
-         super().reset_for_online()
-         self.clear_state_keys('prev_update')
-
-     @torch.no_grad
-     def update(self, var):
-         assert var.closure is not None
-         prev_update = self.get_state(var.params, 'prev_update')
-         hvp_method = self.settings[var.params[0]]['hvp_method']
-         h = self.settings[var.params[0]]['h']
-
-         Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
-         Hvp = [t.detach() for t in Hvp]
-
-         if 'hvp_tfm' in self.children:
-             Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
-
-         self.store(var.params, "Hvp", Hvp)
-
-
-     @torch.no_grad
-     def apply(self, var):
-         update = TensorList(var.get_update())
-         Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
-         mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
-
-         update.add_(prev_update - Hvp*mu)
-         prev_update.set_(update * beta)
-         var.update = update
-         return var
-
-
- class AdaptiveMatrixMomentum(Module):
-     """Second order momentum method.
-
-     Matrix momentum is useful for convex objectives, also for some reason it has very good generalization on elastic net logistic regression.
-
-     .. note::
-         In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
-
-     .. note::
-         This module requires the a closure passed to the optimizer step,
-         as it needs to re-evaluate the loss and gradients for calculating HVPs.
-         The closure must accept a ``backward`` argument (refer to documentation).
-
-
-     Args:
-         mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
-         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
-         hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
-         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
-
-     Reference:
-         Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
-     """
-
-     def __init__(
-         self,
-         mu_mul: float = 1,
-         beta: float = 1,
-         eps=1e-4,
-         hvp_method: Literal["autograd", "forward", "central"] = "autograd",
-         h: float = 1e-3,
-         hvp_tfm: Chainable | None = None,
-     ):
-         defaults = dict(mu_mul=mu_mul, beta=beta, hvp_method=hvp_method, h=h, eps=eps)
-         super().__init__(defaults)
-
-         if hvp_tfm is not None:
-             self.set_child('hvp_tfm', hvp_tfm)
-
-     def reset_for_online(self):
-         super().reset_for_online()
-         self.clear_state_keys('prev_params', 'prev_grad')
-
-     @torch.no_grad
-     def update(self, var):
-         assert var.closure is not None
-         prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
-
-         settings = self.settings[var.params[0]]
-         hvp_method = settings['hvp_method']
-         h = settings['h']
-         eps = settings['eps']
-
-         mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
-
-         Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
-         Hvp = [t.detach() for t in Hvp]
-
-         if 'hvp_tfm' in self.children:
-             Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
-
-         # adaptive part
-         s_k = var.params - prev_params
-         prev_params.copy_(var.params)
-
-         if hvp_method != 'central': assert var.grad is not None
-         grad = var.get_grad()
-         y_k = grad - prev_grad
-         prev_grad.copy_(grad)
-
-         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
-
-         self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
-
-     @torch.no_grad
-     def apply(self, var):
-         Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
-         Hvp = as_tensorlist(Hvp)
-         beta = NumberList(self.settings[p]['beta'] for p in var.params)
-         update = TensorList(var.get_update())
-         prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
-
-         update.add_(prev_update - Hvp*ada_mu)
-         prev_update.set_(update * beta)
-         var.update = update
-         return var
-
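Both removed classes implement the same core recursion, update <- grad + prev_update - mu * (H @ prev_update), i.e. heavy-ball momentum whose scalar decay is replaced by (I - mu*H); the adaptive variant estimates mu as ||s_k|| / ||y_k|| times `mu_mul`. A minimal sketch of that recursion on a toy quadratic with a known Hessian (illustrative only; the module itself obtains H @ prev_update through autograd or finite-difference HVPs, and the learning rate and iteration count here are arbitrary):

```python
import torch

H = torch.diag(torch.tensor([1.0, 10.0]))   # Hessian of f(x) = 0.5 * x @ H @ x
x = torch.tensor([5.0, 5.0])                # parameters
u = torch.zeros(2)                          # plays the role of prev_update
lr, mu = 0.05, 0.09                         # mu kept below 1 / largest eigenvalue (10)

for _ in range(300):
    g = H @ x                               # exact gradient of the toy quadratic
    u = g + u - mu * (H @ u)                # update.add_(prev_update - Hvp * mu), with beta = 1
    x = x - lr * u                          # a separate lr module would normally scale the step

print(x)                                    # tends toward the minimizer at the origin
```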
torchzero/modules/optimizers/adagrad.py (deleted)
@@ -1,165 +0,0 @@
- from operator import itemgetter
- from typing import Literal
-
- import torch
- from ...core import (
-     Chainable,
-     Module,
-     Target,
-     TensorwiseTransform,
-     Transform,
-     Var,
-     apply_transform,
- )
- from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ...utils.linalg import matrix_power_eigh
- from ..functional import add_power_, lerp_power_, root
-
-
- def adagrad_(
-     tensors_: TensorList,
-     sq_sum_: TensorList,
-     alpha: float | NumberList,
-     lr_decay: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     use_sqrt: bool = True,
-     divide: bool = False,
-
-     # inner args
-     inner: Module | None = None,
-     params: list[torch.Tensor] | None = None,
-     grads: list[torch.Tensor] | None = None,
- ):
-     """returns `tensors_`"""
-     clr = alpha / (1 + step * lr_decay)
-
-     sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
-
-     if inner is not None:
-         assert params is not None
-         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
-
-     if divide: sq_sum_ = sq_sum_ / max(step, 1)
-
-     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
-     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
-
-     return tensors_
-
-
-
- class Adagrad(Transform):
-     """Adagrad, divides by sum of past squares of gradients.
-
-     This implementation is identical to :code:`torch.optim.Adagrad`.
-
-     Args:
-         lr_decay (float, optional): learning rate decay. Defaults to 0.
-         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-         eps (float, optional): division epsilon. Defaults to 1e-10.
-         alpha (float, optional): step size. Defaults to 1.
-         pow (float, optional): power for gradients and accumulator root. Defaults to 2.
-         use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
-         inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
-     """
-     def __init__(
-         self,
-         lr_decay: float = 0,
-         initial_accumulator_value: float = 0,
-         eps: float = 1e-10,
-         alpha: float = 1,
-         pow: float = 2,
-         use_sqrt: bool = True,
-         divide: bool=False,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                         eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
-         super().__init__(defaults=defaults, uses_grad=False)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = TensorList(tensors)
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
-
-         pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
-
-         sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
-
-         # initialize accumulator on 1st step
-         if step == 1:
-             sq_sum.set_(tensors.full_like([s['initial_accumulator_value'] for s in settings]))
-
-         return adagrad_(
-             tensors,
-             sq_sum_=sq_sum,
-             alpha=alpha,
-             lr_decay=lr_decay,
-             eps=eps,
-             step=self.global_state["step"],
-             pow=pow,
-             use_sqrt=use_sqrt,
-             divide=divide,
-
-             # inner args
-             inner=self.children.get("inner", None),
-             params=params,
-             grads=grads,
-         )
-
-
-
- class FullMatrixAdagrad(TensorwiseTransform):
-     def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
-         defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
-         super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)
-
-     @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, setting):
-         G = tensor.ravel()
-         GG = torch.outer(G, G)
-         decay = setting['decay']
-         beta = setting['beta']
-         init = setting['init']
-
-         if 'GG' not in state:
-             if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
-             elif init == 'zeros': state['GG'] = torch.zeros_like(GG)
-             elif init == 'ones': state['GG'] = torch.ones_like(GG)
-             elif init == 'GGT': state['GG'] = GG.clone()
-             else: raise ValueError(init)
-         if decay is not None: state['GG'].mul_(decay)
-
-         if beta is not None: state['GG'].lerp_(GG, 1-beta)
-         else: state['GG'].add_(GG)
-         state['i'] = state.get('i', 0) + 1 # number of GGTs in sum
-
-     @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, setting):
-         GG = state['GG']
-         sqrt = setting['sqrt']
-         divide = setting['divide']
-         if divide: GG = GG/state.get('i', 1)
-
-         if tensor.numel() == 1:
-             GG = GG.squeeze()
-             if sqrt: return tensor / GG.sqrt()
-             return tensor / GG
-
-         try:
-             if sqrt: B = matrix_power_eigh(GG, -1/2)
-             else: return torch.linalg.solve(GG, tensor.ravel()).view_as(tensor) # pylint:disable = not-callable
-
-         except torch.linalg.LinAlgError:
-             scale = 1 / tensor.abs().max()
-             return tensor.mul_(scale.clip(min=torch.finfo(tensor.dtype).eps, max=1)) # conservative scaling
-
-         return (B @ tensor.ravel()).view_as(tensor)
-
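For reference, the removed `FullMatrixAdagrad` preconditions each flattened gradient with the inverse square root of an accumulated outer-product matrix; the new `torchzero/modules/adaptive/adagrad.py` listed above appears to take over this role. A minimal sketch of the accumulate-and-whiten step (illustrative only; `inv_sqrt_psd` below stands in for torchzero's `matrix_power_eigh(GG, -1/2)`):

```python
import torch

def inv_sqrt_psd(M, eps=1e-12):
    # M^(-1/2) for a symmetric positive semi-definite matrix via its eigendecomposition
    vals, vecs = torch.linalg.eigh(M)
    return vecs @ torch.diag(vals.clamp_min(eps).rsqrt()) @ vecs.T

dim = 3
GG = torch.eye(dim)                  # init='identity'
for _ in range(100):
    g = torch.randn(dim)             # stand-in for a flattened gradient
    GG += torch.outer(g, g)          # accumulate the outer product, as in update_tensor
    step = inv_sqrt_psd(GG) @ g      # precondition, as in apply_tensor with sqrt=True
```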