torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/line_search/polynomial.py (new file)
@@ -0,0 +1,233 @@
+ import numpy as np
+ import torch
+
+ from .line_search import LineSearchBase
+
+
+ # polynomial interpolation
+ # this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+ # PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+ def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+     """
+     Gives the minimizer and minimum of the interpolating polynomial over given points
+     based on function and derivative information. Defaults to bisection if no critical
+     points are valid.
+
+     Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+     modifications.
+
+     Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+     Last edited 12/6/18.
+
+     Inputs:
+         points (nparray): two-dimensional array with each point of form [x f g]
+         x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+         x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+         plot (bool): plot interpolating polynomial
+
+     Outputs:
+         x_sol (float): minimizer of interpolating polynomial
+         F_min (float): minimum of interpolating polynomial
+
+     Note:
+       . Set f or g to np.nan if they are unknown
+
+     """
+     no_points = points.shape[0]
+     order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+     x_min = np.min(points[:, 0])
+     x_max = np.max(points[:, 0])
+
+     # compute bounds of interpolation area
+     if x_min_bound is None:
+         x_min_bound = x_min
+     if x_max_bound is None:
+         x_max_bound = x_max
+
+     # explicit formula for quadratic interpolation
+     if no_points == 2 and order == 2 and plot is False:
+         # Solution to quadratic interpolation is given by:
+         #   a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+         #   x_min = x1 - g1/(2a)
+         # if x1 = 0, then is given by:
+         #   x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+         if points[0, 0] == 0:
+             x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+         else:
+             a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+             x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+         x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+     # explicit formula for cubic interpolation
+     elif no_points == 2 and order == 3 and plot is False:
+         # Solution to cubic interpolation is given by:
+         #   d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+         #   d2 = sqrt(d1^2 - g1*g2)
+         #   x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+         d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+         d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
+         if np.isreal(d2):
+             x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+             x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+         else:
+             x_sol = (x_max_bound + x_min_bound)/2
+
+     # solve linear system
+     else:
+         # define linear constraints
+         A = np.zeros((0, order + 1))
+         b = np.zeros((0, 1))
+
+         # add linear constraints on function values
+         for i in range(no_points):
+             if not np.isnan(points[i, 1]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order, -1, -1):
+                     constraint[0, order - j] = points[i, 0] ** j
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 1])
+
+         # add linear constraints on gradient values
+         for i in range(no_points):
+             if not np.isnan(points[i, 2]):
+                 constraint = np.zeros((1, order + 1))
+                 for j in range(order):
+                     constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                 A = np.append(A, constraint, 0)
+                 b = np.append(b, points[i, 2])
+
+         # check if system is solvable
+         if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+             x_sol = (x_min_bound + x_max_bound)/2
+             f_min = np.inf
+         else:
+             # solve linear system for interpolating polynomial
+             coeff = np.linalg.solve(A, b)
+
+             # compute critical points
+             dcoeff = np.zeros(order)
+             for i in range(len(coeff) - 1):
+                 dcoeff[i] = coeff[i] * (order - i)
+
+             crit_pts = np.array([x_min_bound, x_max_bound])
+             crit_pts = np.append(crit_pts, points[:, 0])
+
+             if not np.isinf(dcoeff).any():
+                 roots = np.roots(dcoeff)
+                 crit_pts = np.append(crit_pts, roots)
+
+             # test critical points
+             f_min = np.inf
+             x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
+             for crit_pt in crit_pts:
+                 if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                     F_cp = np.polyval(coeff, crit_pt)
+                     if np.isreal(F_cp) and F_cp < f_min:
+                         x_sol = np.real(crit_pt)
+                         f_min = np.real(F_cp)
+
+             if(plot):
+                 import matplotlib.pyplot as plt
+                 plt.figure()
+                 x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                 f = np.polyval(coeff, x)
+                 plt.plot(x, f)
+                 plt.plot(x_sol, f_min, 'x')
+
+     return x_sol
+
+
+
+ # class PolynomialLineSearch(LineSearch):
+ #     """TODO
+
+ #     Line search via polynomial interpolation.
+
+ #     Args:
+ #         init (float, optional): Initial step size. Defaults to 1.0.
+ #         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+ #         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+ #         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+ #         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+ #         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+ #         adaptive (bool, optional):
+ #             when enabled, if line search failed, initial step size is reduced.
+ #             Otherwise it is reset to initial value. Defaults to True.
+ #         plus_minus (bool, optional):
+ #             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+ #     Examples:
+ #         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+ #         .. code-block:: python
+
+ #             opt = tz.Modular(
+ #                 model.parameters(),
+ #                 tz.m.PolakRibiere(),
+ #                 tz.m.StrongWolfe(c2=0.1)
+ #             )
+
+ #         LBFGS strong wolfe line search:
+
+ #         .. code-block:: python
+
+ #             opt = tz.Modular(
+ #                 model.parameters(),
+ #                 tz.m.LBFGS(),
+ #                 tz.m.StrongWolfe()
+ #             )
+
+ #     """
+ #     def __init__(
+ #         self,
+ #         init: float = 1.0,
+ #         c1: float = 1e-4,
+ #         c2: float = 0.9,
+ #         maxiter: int = 25,
+ #         maxzoom: int = 10,
+ #         # a_max: float = 1e10,
+ #         expand: float = 2.0,
+ #         adaptive = True,
+ #         plus_minus = False,
+ #     ):
+ #         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
+ #                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+ #         super().__init__(defaults=defaults)
+
+ #         self.global_state['initial_scale'] = 1.0
+ #         self.global_state['beta_scale'] = 1.0
+
+ #     @torch.no_grad
+ #     def search(self, update, var):
+ #         objective = self.make_objective_with_derivative(var=var)
+
+ #         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
+ #             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+ #             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+
+ #         f_0, g_0 = objective(0)
+
+ #         step_size,f_a = strong_wolfe(
+ #             objective,
+ #             f_0=f_0, g_0=g_0,
+ #             init=init * self.global_state.setdefault("initial_scale", 1),
+ #             c1=c1,
+ #             c2=c2,
+ #             maxiter=maxiter,
+ #             maxzoom=maxzoom,
+ #             expand=expand,
+ #             plus_minus=plus_minus,
+ #         )
+
+ #         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+ #         if step_size is not None and step_size != 0 and not _notfinite(step_size):
+ #             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+ #             return step_size
+
+ #         # fallback to backtracking on fail
+ #         if adaptive: self.global_state['initial_scale'] *= 0.5
+ #         return 0
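
For orientation, here is a small self-contained check of polyinterp on two points of the form [x, f, g] (exercising the explicit cubic branch above). The import path is inferred from the file list in this diff and the sample values are illustrative, not taken from the package's tests.

import numpy as np
from torchzero.modules.line_search.polynomial import polyinterp  # path inferred from this diff

# Two points bracketing the minimum of f(x) = (x - 1)**2:
#   at x = 0: f = 1, f' = -2      at x = 2: f = 1, f' = 2
points = np.array([
    [0.0, 1.0, -2.0],
    [2.0, 1.0,  2.0],
])

# Both f and g are known at both points, so the explicit cubic formula is used;
# here the interpolant recovers the true minimizer x = 1 exactly.
x_sol = polyinterp(points)
print(x_sol)  # 1.0, clipped to [min(x), max(x)] = [0, 2]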
torchzero/modules/line_search/scipy.py
@@ -3,10 +3,10 @@ from operator import itemgetter
  
   import torch
  
 - from .line_search import LineSearch
 + from .line_search import LineSearchBase
  
  
 - class ScipyMinimizeScalar(LineSearch):
 + class ScipyMinimizeScalar(LineSearchBase):
       """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
  
       Args:
torchzero/modules/line_search/strong_wolfe.py
@@ -1,3 +1,4 @@
 + """this needs to be reworked maybe but it also works"""
   import math
   import warnings
   from operator import itemgetter
@@ -5,8 +6,7 @@ from operator import itemgetter
   import torch
   from torch.optim.lbfgs import _cubic_interpolate
  
 - from .line_search import LineSearch
 - from .backtracking import backtracking_line_search
 + from .line_search import LineSearchBase
   from ...utils import totensor
  
  
@@ -182,7 +182,7 @@ def _notfinite(x):
       if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
       return not math.isfinite(x)
  
 - class StrongWolfe(LineSearch):
 + class StrongWolfe(LineSearchBase):
       """Cubic interpolation line search satisfying Strong Wolfe condition.
  
       Args:
@@ -192,11 +192,36 @@ class StrongWolfe(LineSearch):
           maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
           maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
           expand (float, optional): Expansion factor (multiplier to step size when weak condition not satisfied). Defaults to 2.0.
 +         use_prev (bool, optional):
 +             if True, previous step size is used as the initial step size on the next step.
           adaptive (bool, optional):
               when enabled, if line search failed, initial step size is reduced.
               Otherwise it is reset to initial value. Defaults to True.
           plus_minus (bool, optional):
               If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
 +
 +
 +     Examples:
 +         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
 +
 +         .. code-block:: python
 +
 +             opt = tz.Modular(
 +                 model.parameters(),
 +                 tz.m.PolakRibiere(),
 +                 tz.m.StrongWolfe(c2=0.1)
 +             )
 +
 +         LBFGS strong wolfe line search:
 +
 +         .. code-block:: python
 +
 +             opt = tz.Modular(
 +                 model.parameters(),
 +                 tz.m.LBFGS(),
 +                 tz.m.StrongWolfe()
 +             )
 +
       """
       def __init__(
           self,
@@ -207,11 +232,12 @@ class StrongWolfe(LineSearch):
           maxzoom: int = 10,
           # a_max: float = 1e10,
           expand: float = 2.0,
 +         use_prev: bool = False,
           adaptive = True,
           plus_minus = False,
       ):
           defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
 -                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
 +                       expand=expand, adaptive=adaptive, plus_minus=plus_minus,use_prev=use_prev)
           super().__init__(defaults=defaults)
  
           self.global_state['initial_scale'] = 1.0
@@ -221,11 +247,12 @@ class StrongWolfe(LineSearch):
       def search(self, update, var):
           objective = self.make_objective_with_derivative(var=var)
  
 -         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
 +         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus, use_prev = itemgetter(
               'init', 'c1', 'c2', 'maxiter', 'maxzoom',
 -             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
 +             'expand', 'adaptive', 'plus_minus', 'use_prev')(self.settings[var.params[0]])
  
           f_0, g_0 = objective(0)
 +         if use_prev: init = self.global_state.get('prev_alpha', init)
  
           step_size,f_a = strong_wolfe(
               objective,
@@ -242,8 +269,8 @@ class StrongWolfe(LineSearch):
           if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
           if step_size is not None and step_size != 0 and not _notfinite(step_size):
               self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
 +             self.global_state['prev_alpha'] = step_size
               return step_size
  
 -         # fallback to backtracking on fail
           if adaptive: self.global_state['initial_scale'] *= 0.5
           return 0
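
The new use_prev option warm-starts each search from the previously accepted step size (stored as prev_alpha above). A minimal sketch following the docstring examples; `model` is a placeholder torch.nn.Module:

import torchzero as tz

# L-BFGS with a Strong Wolfe line search that reuses the last accepted
# step size as the initial trial step on the next iteration.
opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.StrongWolfe(use_prev=True),
)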
torchzero/modules/misc/__init__.py (new file)
@@ -0,0 +1,27 @@
+ from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
+ from .escape import EscapeAnnealing
+ from .gradient_accumulation import GradientAccumulation
+ from .misc import (
+     DivByLoss,
+     FillLoss,
+     GradSign,
+     GraftGradToUpdate,
+     GraftToGrad,
+     GraftToParams,
+     HpuEstimate,
+     LastAbsoluteRatio,
+     LastDifference,
+     LastGradDifference,
+     LastProduct,
+     LastRatio,
+     MulByLoss,
+     NoiseSign,
+     Previous,
+     RandomHvp,
+     Relative,
+     UpdateSign,
+ )
+ from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
+ from .regularization import Dropout, PerturbWeights, WeightDropout
+ from .split import Split
+ from .switch import Alternate, Switch
torchzero/modules/{ops → misc}/debug.py
@@ -6,6 +6,7 @@ from ...core import Module
   from ...utils.tensorlist import Distributions
  
   class PrintUpdate(Module):
 +     """Prints current update."""
       def __init__(self, text = 'update = ', print_fn = print):
           defaults = dict(text=text, print_fn=print_fn)
           super().__init__(defaults)
@@ -15,6 +16,7 @@ class PrintUpdate(Module):
           return var
  
   class PrintShape(Module):
 +     """Prints shapes of the update."""
       def __init__(self, text = 'shapes = ', print_fn = print):
           defaults = dict(text=text, print_fn=print_fn)
           super().__init__(defaults)
@@ -22,4 +24,25 @@ class PrintShape(Module):
       def step(self, var):
           shapes = [u.shape for u in var.update] if var.update is not None else None
           self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{shapes}')
 -         return var
 +         return var
 +
 + class PrintParams(Module):
 +     """Prints current parameters."""
 +     def __init__(self, text = 'params = ', print_fn = print):
 +         defaults = dict(text=text, print_fn=print_fn)
 +         super().__init__(defaults)
 +
 +     def step(self, var):
 +         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.params}')
 +         return var
 +
 +
 + class PrintLoss(Module):
 +     """Prints var.get_loss()."""
 +     def __init__(self, text = 'loss = ', print_fn = print):
 +         defaults = dict(text=text, print_fn=print_fn)
 +         super().__init__(defaults)
 +
 +     def step(self, var):
 +         self.settings[var.params[0]]["print_fn"](f'{self.settings[var.params[0]]["text"]}{var.get_loss(False)}')
 +         return var
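
As a usage sketch, debug modules can be inserted between other modules to print what flows through the chain. The composition below is illustrative; tz.m.Adam and tz.m.LR follow the docstring examples elsewhere in this diff.

import torchzero as tz

opt = tz.Modular(
    model.parameters(),   # `model` is a placeholder torch.nn.Module
    tz.m.Adam(),
    tz.m.PrintLoss(),     # prints "loss = ..." via var.get_loss(False)
    tz.m.PrintUpdate(),   # prints the update produced by Adam
    tz.m.LR(1e-3),
)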
torchzero/modules/misc/escape.py (new file)
@@ -0,0 +1,60 @@
+ import torch
+
+ from ...core import Module
+ from ...utils import TensorList, NumberList
+
+
+ class EscapeAnnealing(Module):
+     """If parameters stop changing, this runs a backward annealing random search"""
+     def __init__(self, max_region: float = 1, max_iter: int = 1000, tol=1e-6, n_tol: int = 10):
+         defaults = dict(max_region=max_region, max_iter=max_iter, tol=tol, n_tol=n_tol)
+         super().__init__(defaults)
+
+
+     @torch.no_grad
+     def step(self, var):
+         closure = var.closure
+         if closure is None: raise RuntimeError("Escape requires closure")
+
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         max_region = self.get_settings(params, 'max_region', cls=NumberList)
+         max_iter = settings['max_iter']
+         tol = settings['tol']
+         n_tol = settings['n_tol']
+
+         n_bad = self.global_state.get('n_bad', 0)
+
+         prev_params = self.get_state(params, 'prev_params', cls=TensorList)
+         diff = params - prev_params
+         prev_params.copy_(params)
+
+         if diff.abs().global_max() <= tol:
+             n_bad += 1
+
+         else:
+             n_bad = 0
+
+         self.global_state['n_bad'] = n_bad
+
+         # no progress
+         f_0 = var.get_loss(False)
+         if n_bad >= n_tol:
+             for i in range(1, max_iter+1):
+                 alpha = max_region * (i / max_iter)
+                 pert = params.sample_like(distribution='sphere').mul_(alpha)
+
+                 params.add_(pert)
+                 f_star = closure(False)
+
+                 if f_star < f_0 - 1e-10:
+                     var.update = None
+                     var.stop = True
+                     var.skip_update = True
+                     return var
+
+                 else:
+                     params.sub_(pert)
+
+             self.global_state['n_bad'] = 0
+         return var
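
For context, a minimal sketch of how EscapeAnnealing might be composed. This assumes EscapeAnnealing is exported under tz.m like the other modules in this diff and that the usual torchzero closure convention (a backward flag) applies; the placement after the line search and all data names are illustrative.

import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.LBFGS(),
    tz.m.StrongWolfe(),
    tz.m.EscapeAnnealing(max_region=1.0, n_tol=10),  # kicks in once params stall for n_tol steps
)

def closure(backward=True):
    loss = loss_fn(model(inputs), targets)  # model/loss_fn/inputs/targets are placeholders
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

opt.step(closure)  # EscapeAnnealing re-evaluates the closure while probing random perturbations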
torchzero/modules/misc/gradient_accumulation.py (new file)
@@ -0,0 +1,70 @@
+ import torch
+
+ from ...core import Chainable, Module
+
+
+ class GradientAccumulation(Module):
+     """Uses :code:`n` steps to accumulate gradients; after :code:`n` gradients have been accumulated, they are passed to :code:`modules` and parameters are updated.
+
+     Accumulating gradients for :code:`n` steps is equivalent to increasing the batch size by :code:`n`. Increasing the batch size
+     is more computationally efficient, but sometimes it is not feasible due to memory constraints.
+
+     .. note::
+         Technically this can accumulate any inputs, including updates generated by previous modules. As long as this module is first, it will accumulate the gradients.
+
+     Args:
+         modules (Chainable): modules that perform a step every :code:`n` steps using the accumulated gradients.
+         n (int): number of gradients to accumulate.
+         mean (bool, optional): if True, uses mean of accumulated gradients, otherwise uses sum. Defaults to True.
+         stop (bool, optional):
+             this module prevents next modules from stepping unless :code:`n` gradients have been accumulated. Setting this argument to False disables that. Defaults to True.
+
+     Examples:
+         Adam with gradients accumulated for 16 batches.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.GradientAccumulation(
+                     modules=[tz.m.Adam(), tz.m.LR(1e-2)],
+                     n=16
+                 )
+             )
+
+     """
+     def __init__(self, modules: Chainable, n: int, mean=True, stop=True):
+         defaults = dict(n=n, mean=mean, stop=stop)
+         super().__init__(defaults)
+         self.set_child('modules', modules)
+
+
+     @torch.no_grad
+     def step(self, var):
+         accumulator = self.get_state(var.params, 'accumulator')
+         settings = self.settings[var.params[0]]
+         n = settings['n']; mean = settings['mean']; stop = settings['stop']
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         # add update to accumulator
+         torch._foreach_add_(accumulator, var.get_update())
+
+         # step with accumulated updates
+         if step % n == 0:
+             if mean:
+                 torch._foreach_div_(accumulator, n)
+
+             var.update = [a.clone() for a in accumulator]
+             var = self.children['modules'].step(var)
+
+             # zero accumulator
+             torch._foreach_zero_(accumulator)
+
+         else:
+             # prevent update
+             if stop:
+                 var.stop = True
+                 var.skip_update = True
+
+         return var
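
As a usage sketch of the behavior described above: the optimizer is stepped on every mini-batch, but the wrapped Adam/LR chain only applies a parameter update on every 16th call. The closure convention with a backward flag follows torchzero's usual pattern; the data-loading names are placeholders.

import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.GradientAccumulation(modules=[tz.m.Adam(), tz.m.LR(1e-2)], n=16),
)

for inputs, targets in loader:  # placeholder data loader
    def closure(backward=True):
        loss = loss_fn(model(inputs), targets)
        if backward:
            opt.zero_grad()
            loss.backward()
        return loss

    # each call adds the current gradient to the accumulator; with stop=True
    # (the default) the wrapped modules only produce an update on every 16th call
    opt.step(closure)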