torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0

torchzero/modules/line_search/line_search.py
@@ -8,15 +8,16 @@ from typing import Any
 import numpy as np
 import torch
 
-from ...core import Module, Target, Vars
+from ...core import Module, Target, Var
 from ...utils import tofloat
 
 
 class MaxLineSearchItersReached(Exception): pass
 
 
-class LineSearch(Module, ABC):
+class LineSearchBase(Module, ABC):
     """Base class for line searches.
+
     This is an abstract class, to use it, subclass it and override `search`.
 
     Args:
@@ -26,6 +27,62 @@ class LineSearch(Module, ABC):
             the objective this many times, and step size with the lowest loss value will be used.
             This is useful when passing `make_objective` to an external library which
             doesn't have a maxiter option. Defaults to None.
+
+    Other useful methods:
+        * `evaluate_step_size` - returns loss with a given scalar step size
+        * `evaluate_step_size_loss_and_derivative` - returns loss and directional derivative with a given scalar step size
+        * `make_objective` - creates a function that accepts a scalar step size and returns loss. This can be passed to a scalar solver, such as scipy.optimize.minimize_scalar.
+        * `make_objective_with_derivative` - creates a function that accepts a scalar step size and returns a tuple with loss and directional derivative. This can be passed to a scalar solver.
+
+    Examples:
+        #### Basic line search
+
+        This evaluates all step sizes in a range by using the :code:`self.evaluate_step_size` method.
+
+        .. code-block:: python
+
+            class GridLineSearch(LineSearch):
+                def __init__(self, start, end, num):
+                    defaults = dict(start=start,end=end,num=num)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    settings = self.settings[var.params[0]]
+                    start = settings["start"]
+                    end = settings["end"]
+                    num = settings["num"]
+
+                    lowest_loss = float("inf")
+                    best_step_size = best_step_size
+
+                    for step_size in torch.linspace(start,end,num):
+                        loss = self.evaluate_step_size(step_size.item(), var=var, backward=False)
+                        if loss < lowest_loss:
+                            lowest_loss = loss
+                            best_step_size = step_size
+
+                    return best_step_size
+
+        #### Using external solver via self.make_objective
+
+        Here we let :code:`scipy.optimize.minimize_scalar` solver find the best step size via :code:`self.make_objective`
+
+        .. code-block:: python
+
+            class ScipyMinimizeScalar(LineSearch):
+                def __init__(self, method: str | None = None):
+                    defaults = dict(method=method)
+                    super().__init__(defaults)
+
+                @torch.no_grad
+                def search(self, update, var):
+                    objective = self.make_objective(var=var)
+                    method = self.settings[var.params[0]]["method"]
+
+                    res = self.scopt.minimize_scalar(objective, method=method)
+                    return res.x
+
     """
     def __init__(self, defaults: dict[str, Any] | None, maxiter: int | None = None):
         super().__init__(defaults)
@@ -62,12 +119,12 @@ class LineSearch(Module, ABC):
         if any(a!=0 for a in alpha):
             torch._foreach_add_(params, torch._foreach_mul(update, alpha))
 
-    def _loss(self, step_size: float, vars: Vars, closure, params: list[torch.Tensor],
+    def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
               update: list[torch.Tensor], backward:bool=False) -> float:
 
         # if step_size is 0, we might already know the loss
-        if (vars.loss is not None) and (step_size == 0):
-            return tofloat(vars.loss)
+        if (var.loss is not None) and (step_size == 0):
+            return tofloat(var.loss)
 
         # check max iter
         if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
@@ -85,23 +142,23 @@ class LineSearch(Module, ABC):
             self._lowest_loss = tofloat(loss)
             self._best_step_size = step_size
 
-        # if evaluated loss at step size 0, set it to vars.loss
+        # if evaluated loss at step size 0, set it to var.loss
         if step_size == 0:
-            vars.loss = loss
-            if backward: vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+            var.loss = loss
+            if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
 
         return tofloat(loss)
 
-    def _loss_derivative(self, step_size: float, vars: Vars, closure,
+    def _loss_derivative(self, step_size: float, var: Var, closure,
                          params: list[torch.Tensor], update: list[torch.Tensor]):
         # if step_size is 0, we might already know the derivative
-        if (vars.grad is not None) and (step_size == 0):
-            loss = self._loss(step_size=step_size,vars=vars,closure=closure,params=params,update=update,backward=False)
-            derivative = - sum(t.sum() for t in torch._foreach_mul(vars.grad, update))
+        if (var.grad is not None) and (step_size == 0):
+            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
+            derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))
 
         else:
             # loss with a backward pass sets params.grad
-            loss = self._loss(step_size=step_size,vars=vars,closure=closure,params=params,update=update,backward=True)
+            loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)
 
             # directional derivative
             derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
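
In the hunk above, `_loss_derivative` reports the directional derivative as the negative of the gradient-update dot product. That sign is consistent with `update` being a descent direction that is ultimately subtracted from the parameters, i.e. the scalar function being searched is f(params - a*update); this reading is inferred from the surrounding code, not stated in the diff. A tiny standalone sketch of the same expression with made-up tensors:

    import torch

    # hypothetical per-parameter gradient and update lists
    grad   = [torch.tensor([1.0, -2.0]), torch.tensor([0.5])]
    update = [torch.tensor([0.1,  0.3]), torch.tensor([2.0])]

    # same expression as in the hunk above
    derivative = -sum(t.sum() for t in torch._foreach_mul(grad, update))
    print(float(derivative))  # -0.5, the slope at step size 0 under this convention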
@@ -109,73 +166,74 @@ class LineSearch(Module, ABC):
 
         return loss, tofloat(derivative)
 
-    def evaluate_step_size(self, step_size: float, vars: Vars, backward:bool=False):
-        closure = vars.closure
+    def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss(step_size=step_size, vars=vars, closure=closure, params=vars.params,update=vars.get_update(),backward=backward)
+        return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)
 
-    def evaluate_step_size_loss_and_derivative(self, step_size: float, vars: Vars):
-        closure = vars.closure
+    def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return self._loss_derivative(step_size=step_size, vars=vars, closure=closure, params=vars.params,update=vars.get_update())
+        return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())
 
-    def make_objective(self, vars: Vars, backward:bool=False):
-        closure = vars.closure
+    def make_objective(self, var: Var, backward:bool=False):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+        return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)
 
-    def make_objective_with_derivative(self, vars: Vars):
-        closure = vars.closure
+    def make_objective_with_derivative(self, var: Var):
+        closure = var.closure
         if closure is None: raise RuntimeError('line search requires closure')
-        return partial(self._loss_derivative, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+        return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())
 
     @abstractmethod
-    def search(self, update: list[torch.Tensor], vars: Vars) -> float:
+    def search(self, update: list[torch.Tensor], var: Var) -> float:
         """Finds the step size to use"""
 
     @torch.no_grad
-    def step(self, vars: Vars) -> Vars:
+    def step(self, var: Var) -> Var:
         self._reset()
-        params = vars.params
-        update = vars.get_update()
+        params = var.params
+        update = var.get_update()
 
         try:
-            step_size = self.search(update=update, vars=vars)
+            step_size = self.search(update=update, var=var)
         except MaxLineSearchItersReached:
             step_size = self._best_step_size
 
         # set loss_approx
-        if vars.loss_approx is None: vars.loss_approx = self._lowest_loss
+        if var.loss_approx is None: var.loss_approx = self._lowest_loss
 
         # this is last module - set step size to found step_size times lr
-        if vars.is_last:
+        if var.is_last:
 
-            if vars.last_module_lrs is None:
+            if var.last_module_lrs is None:
                 self.set_step_size_(step_size, params=params, update=update)
 
             else:
-                self._set_per_parameter_step_size_([step_size*lr for lr in vars.last_module_lrs], params=params, update=update)
+                self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)
 
-            vars.stop = True; vars.skip_update = True
-            return vars
+            var.stop = True; var.skip_update = True
+            return var
 
         # revert parameters and multiply update by step size
         self.set_step_size_(0, params=params, update=update)
-        torch._foreach_mul_(vars.update, step_size)
-        return vars
+        torch._foreach_mul_(var.update, step_size)
+        return var
 
 
-class GridLineSearch(LineSearch):
-    """Mostly for testing, this is not practical"""
-    def __init__(self, start, end, num):
-        defaults = dict(start=start,end=end,num=num)
-        super().__init__(defaults)
 
-    @torch.no_grad
-    def search(self, update, vars):
-        start,end,num=itemgetter('start','end','num')(self.settings[vars.params[0]])
+# class GridLineSearch(LineSearch):
+#     """Mostly for testing, this is not practical"""
+#     def __init__(self, start, end, num):
+#         defaults = dict(start=start,end=end,num=num)
+#         super().__init__(defaults)
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])
 
-        for lr in torch.linspace(start,end,num):
-            self.evaluate_step_size(lr.item(), vars=vars, backward=False)
+#         for lr in torch.linspace(start,end,num):
+#             self.evaluate_step_size(lr.item(), var=var, backward=False)
 
-        return self._best_step_size
+#         return self._best_step_size
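
Taken together, the line_search.py hunks above boil down to two renames that any downstream subclass has to follow: the base class `LineSearch` is now `LineSearchBase`, and the `Vars` type with its `vars=` argument is now `Var`/`var=`. A minimal sketch of a custom line search written against the 0.3.11 names; it assumes `LineSearchBase` is re-exported from `torchzero.modules.line_search` (the changed `__init__.py` in the file list suggests this, but that hunk is not shown), and the halving strategy itself is purely illustrative:

    import torch
    from torchzero.core import Var                             # 0.3.9 name: Vars
    from torchzero.modules.line_search import LineSearchBase   # 0.3.9 name: LineSearch

    class HalvingLineSearch(LineSearchBase):
        """Illustrative only: halve the step size until the loss decreases."""
        def __init__(self, init: float = 1.0, maxiter: int = 10):
            super().__init__(dict(init=init), maxiter=maxiter)

        @torch.no_grad
        def search(self, update, var: Var) -> float:
            init = self.settings[var.params[0]]["init"]
            f_0 = self.evaluate_step_size(0.0, var=var, backward=False)
            step = init
            for _ in range(20):  # local cap; the base class also enforces maxiter
                if self.evaluate_step_size(step, var=var, backward=False) < f_0:
                    return step
                step *= 0.5
            return 0.0  # no acceptable step found, do not move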

torchzero/modules/line_search/polynomial.py (new file)
@@ -0,0 +1,233 @@
+import numpy as np
+import torch
+
+from .line_search import LineSearchBase
+
+
+# polynomial interpolation
+# this code is from https://github.com/hjmshi/PyTorch-LBFGS/blob/master/functions/LBFGS.py
+# PyTorch-LBFGS: A PyTorch Implementation of L-BFGS
+def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
+    """
+    Gives the minimizer and minimum of the interpolating polynomial over given points
+    based on function and derivative information. Defaults to bisection if no critical
+    points are valid.
+
+    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
+    modifications.
+
+    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
+    Last edited 12/6/18.
+
+    Inputs:
+        points (nparray): two-dimensional array with each point of form [x f g]
+        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
+        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
+        plot (bool): plot interpolating polynomial
+
+    Outputs:
+        x_sol (float): minimizer of interpolating polynomial
+        F_min (float): minimum of interpolating polynomial
+
+    Note:
+        . Set f or g to np.nan if they are unknown
+
+    """
+    no_points = points.shape[0]
+    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1
+
+    x_min = np.min(points[:, 0])
+    x_max = np.max(points[:, 0])
+
+    # compute bounds of interpolation area
+    if x_min_bound is None:
+        x_min_bound = x_min
+    if x_max_bound is None:
+        x_max_bound = x_max
+
+    # explicit formula for quadratic interpolation
+    if no_points == 2 and order == 2 and plot is False:
+        # Solution to quadratic interpolation is given by:
+        # a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
+        # x_min = x1 - g1/(2a)
+        # if x1 = 0, then is given by:
+        # x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
+
+        if points[0, 0] == 0:
+            x_sol = -points[0, 2] * points[1, 0] ** 2 / (2 * (points[1, 1] - points[0, 1] - points[0, 2] * points[1, 0]))
+        else:
+            a = -(points[0, 1] - points[1, 1] - points[0, 2] * (points[0, 0] - points[1, 0])) / (points[0, 0] - points[1, 0]) ** 2
+            x_sol = points[0, 0] - points[0, 2]/(2*a)
+
+        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+
+    # explicit formula for cubic interpolation
+    elif no_points == 2 and order == 3 and plot is False:
+        # Solution to cubic interpolation is given by:
+        # d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
+        # d2 = sqrt(d1^2 - g1*g2)
+        # x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
+        d1 = points[0, 2] + points[1, 2] - 3 * ((points[0, 1] - points[1, 1]) / (points[0, 0] - points[1, 0]))
+        d2 = np.sqrt(d1 ** 2 - points[0, 2] * points[1, 2])
+        if np.isreal(d2):
+            x_sol = points[1, 0] - (points[1, 0] - points[0, 0]) * ((points[1, 2] + d2 - d1) / (points[1, 2] - points[0, 2] + 2 * d2))
+            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
+        else:
+            x_sol = (x_max_bound + x_min_bound)/2
+
+    # solve linear system
+    else:
+        # define linear constraints
+        A = np.zeros((0, order + 1))
+        b = np.zeros((0, 1))
+
+        # add linear constraints on function values
+        for i in range(no_points):
+            if not np.isnan(points[i, 1]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order, -1, -1):
+                    constraint[0, order - j] = points[i, 0] ** j
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 1])
+
+        # add linear constraints on gradient values
+        for i in range(no_points):
+            if not np.isnan(points[i, 2]):
+                constraint = np.zeros((1, order + 1))
+                for j in range(order):
+                    constraint[0, j] = (order - j) * points[i, 0] ** (order - j - 1)
+                A = np.append(A, constraint, 0)
+                b = np.append(b, points[i, 2])
+
+        # check if system is solvable
+        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
+            x_sol = (x_min_bound + x_max_bound)/2
+            f_min = np.inf
+        else:
+            # solve linear system for interpolating polynomial
+            coeff = np.linalg.solve(A, b)
+
+            # compute critical points
+            dcoeff = np.zeros(order)
+            for i in range(len(coeff) - 1):
+                dcoeff[i] = coeff[i] * (order - i)
+
+            crit_pts = np.array([x_min_bound, x_max_bound])
+            crit_pts = np.append(crit_pts, points[:, 0])
+
+            if not np.isinf(dcoeff).any():
+                roots = np.roots(dcoeff)
+                crit_pts = np.append(crit_pts, roots)
+
+            # test critical points
+            f_min = np.inf
+            x_sol = (x_min_bound + x_max_bound) / 2  # defaults to bisection
+            for crit_pt in crit_pts:
+                if np.isreal(crit_pt) and crit_pt >= x_min_bound and crit_pt <= x_max_bound:
+                    F_cp = np.polyval(coeff, crit_pt)
+                    if np.isreal(F_cp) and F_cp < f_min:
+                        x_sol = np.real(crit_pt)
+                        f_min = np.real(F_cp)
+
+            if(plot):
+                import matplotlib.pyplot as plt
+                plt.figure()
+                x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound)/10000)
+                f = np.polyval(coeff, x)
+                plt.plot(x, f)
+                plt.plot(x_sol, f_min, 'x')
+
+    return x_sol
+
+
+
+# class PolynomialLineSearch(LineSearch):
+#     """TODO
+
+#     Line search via polynomial interpolation.
+
+#     Args:
+#         init (float, optional): Initial step size. Defaults to 1.0.
+#         c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+#         c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+#         maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+#         maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+#         expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+#         adaptive (bool, optional):
+#             when enabled, if line search failed, initial step size is reduced.
+#             Otherwise it is reset to initial value. Defaults to True.
+#         plus_minus (bool, optional):
+#             If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+#     Examples:
+#         Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.PolakRibiere(),
+#                 tz.m.StrongWolfe(c2=0.1)
+#             )
+
+#         LBFGS strong wolfe line search:
+
+#         .. code-block:: python
+
+#             opt = tz.Modular(
+#                 model.parameters(),
+#                 tz.m.LBFGS(),
+#                 tz.m.StrongWolfe()
+#             )
+
+#     """
+#     def __init__(
+#         self,
+#         init: float = 1.0,
+#         c1: float = 1e-4,
+#         c2: float = 0.9,
+#         maxiter: int = 25,
+#         maxzoom: int = 10,
+#         # a_max: float = 1e10,
+#         expand: float = 2.0,
+#         adaptive = True,
+#         plus_minus = False,
+#     ):
+#         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
+#                       expand=expand, adaptive=adaptive, plus_minus=plus_minus)
+#         super().__init__(defaults=defaults)
+
+#         self.global_state['initial_scale'] = 1.0
+#         self.global_state['beta_scale'] = 1.0
+
+#     @torch.no_grad
+#     def search(self, update, var):
+#         objective = self.make_objective_with_derivative(var=var)
+
+#         init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
+#             'init', 'c1', 'c2', 'maxiter', 'maxzoom',
+#             'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])
+
+#         f_0, g_0 = objective(0)
+
+#         step_size,f_a = strong_wolfe(
+#             objective,
+#             f_0=f_0, g_0=g_0,
+#             init=init * self.global_state.setdefault("initial_scale", 1),
+#             c1=c1,
+#             c2=c2,
+#             maxiter=maxiter,
+#             maxzoom=maxzoom,
+#             expand=expand,
+#             plus_minus=plus_minus,
+#         )
+
+#         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
+#         if step_size is not None and step_size != 0 and not _notfinite(step_size):
+#             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+#             return step_size
+
+#         # fallback to backtracking on fail
+#         if adaptive: self.global_state['initial_scale'] *= 0.5
+#         return 0
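
The only live code in the new polynomial.py is `polyinterp`, which fits a polynomial to rows of the form [x, f, g] (np.nan marks an unknown value) and returns the minimizer clamped to the bracketing bounds; the `PolynomialLineSearch` module that would use it is still commented out. A small sketch of calling it directly, with made-up numbers (the import path follows the file list above):

    import numpy as np
    from torchzero.modules.line_search.polynomial import polyinterp

    # two samples of the 1-D function f(a) = (a - 2)**2:
    #   at a=0: f=4, f'=-4;   at a=3: f=1, derivative unknown (np.nan)
    points = np.array([
        [0.0, 4.0, -4.0],
        [3.0, 1.0, np.nan],
    ])

    alpha = polyinterp(points)  # quadratic fit, clamped to [0, 3]
    print(alpha)                # 2.0, the true minimizer of (a - 2)**2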

torchzero/modules/line_search/scipy.py
@@ -3,10 +3,25 @@ from operator import itemgetter
 
 import torch
 
-from .line_search import LineSearch
+from .line_search import LineSearchBase
 
 
-class ScipyMinimizeScalar(LineSearch):
+class ScipyMinimizeScalar(LineSearchBase):
+    """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
+
+    Args:
+        method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
+        maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
+        bracket (Sequence | None, optional):
+            Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
+        bounds (Sequence | None, optional):
+            For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
+        tol (float | None, optional): Tolerance for termination. Defaults to None.
+        options (dict | None, optional): A dictionary of solver options. Defaults to None.
+
+    For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
+
+    """
     def __init__(
         self,
         method: str | None = None,
@@ -24,10 +39,10 @@ class ScipyMinimizeScalar(LineSearch):
 
 
     @torch.no_grad
-    def search(self, update, vars):
-        objective = self.make_objective(vars=vars)
+    def search(self, update, var):
+        objective = self.make_objective(var=var)
         method, bracket, bounds, tol, options, maxiter = itemgetter(
-            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[vars.params[0]])
+            'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])
 
         if maxiter is not None:
             options = dict(options) if isinstance(options, Mapping) else {}
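
The added docstring covers only the constructor arguments; this hunk does not show how the module is composed. Assuming `ScipyMinimizeScalar` plugs into `tz.Modular` the same way the `StrongWolfe` docstring later in this diff shows, and that it is exposed under `tz.m`, usage might look like this sketch:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)  # any torch module

    opt = tz.Modular(
        model.parameters(),
        tz.m.LBFGS(),
        tz.m.ScipyMinimizeScalar(method="brent", maxiter=20),
    )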

torchzero/modules/line_search/strong_wolfe.py
@@ -1,3 +1,4 @@
+"""this needs to be reworked maybe but it also works"""
 import math
 import warnings
 from operator import itemgetter
@@ -5,8 +6,7 @@ from operator import itemgetter
 import torch
 from torch.optim.lbfgs import _cubic_interpolate
 
-from .line_search import LineSearch
-from .backtracking import backtracking_line_search
+from .line_search import LineSearchBase
 from ...utils import totensor
 
 
@@ -182,7 +182,47 @@ def _notfinite(x):
     if isinstance(x, torch.Tensor): return not torch.isfinite(x).all()
     return not math.isfinite(x)
 
-class StrongWolfe(LineSearch):
+class StrongWolfe(LineSearchBase):
+    """Cubic interpolation line search satisfying Strong Wolfe condition.
+
+    Args:
+        init (float, optional): Initial step size. Defaults to 1.0.
+        c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+        c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+        maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+        maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+        expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+        use_prev (bool, optional):
+            if True, previous step size is used as the initial step size on the next step.
+        adaptive (bool, optional):
+            when enabled, if line search failed, initial step size is reduced.
+            Otherwise it is reset to initial value. Defaults to True.
+        plus_minus (bool, optional):
+            If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+
+
+    Examples:
+        Conjugate gradient method with strong wolfe line search. Nocedal, Wright recommend setting c2 to 0.1 for CG.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.PolakRibiere(),
+                tz.m.StrongWolfe(c2=0.1)
+            )
+
+        LBFGS strong wolfe line search:
+
+        .. code-block:: python

+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(),
+                tz.m.StrongWolfe()
+            )
+
+    """
     def __init__(
         self,
         init: float = 1.0,
@@ -192,26 +232,27 @@ class StrongWolfe(LineSearch):
         maxzoom: int = 10,
         # a_max: float = 1e10,
         expand: float = 2.0,
+        use_prev: bool = False,
         adaptive = True,
-        fallback = False,
         plus_minus = False,
     ):
         defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
-                      expand=expand, adaptive=adaptive, fallback=fallback, plus_minus=plus_minus)
+                      expand=expand, adaptive=adaptive, plus_minus=plus_minus,use_prev=use_prev)
         super().__init__(defaults=defaults)
 
         self.global_state['initial_scale'] = 1.0
         self.global_state['beta_scale'] = 1.0
 
     @torch.no_grad
-    def search(self, update, vars):
-        objective = self.make_objective_with_derivative(vars=vars)
+    def search(self, update, var):
+        objective = self.make_objective_with_derivative(var=var)
 
-        init, c1, c2, maxiter, maxzoom, expand, adaptive, fallback, plus_minus = itemgetter(
+        init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus, use_prev = itemgetter(
            'init', 'c1', 'c2', 'maxiter', 'maxzoom',
-            'expand', 'adaptive', 'fallback', 'plus_minus')(self.settings[vars.params[0]])
+            'expand', 'adaptive', 'plus_minus', 'use_prev')(self.settings[var.params[0]])
 
         f_0, g_0 = objective(0)
+        if use_prev: init = self.global_state.get('prev_alpha', init)
 
         step_size,f_a = strong_wolfe(
             objective,
@@ -228,33 +269,8 @@ class StrongWolfe(LineSearch):
         if f_a is not None and (f_a > f_0 or _notfinite(f_a)): step_size = None
         if step_size is not None and step_size != 0 and not _notfinite(step_size):
             self.global_state['initial_scale'] = min(1.0, self.global_state['initial_scale'] * math.sqrt(2))
+            self.global_state['prev_alpha'] = step_size
             return step_size
 
-        # fallback to backtracking on fail
         if adaptive: self.global_state['initial_scale'] *= 0.5
-        if not fallback: return 0
-
-        objective = self.make_objective(vars=vars)
-
-        # # directional derivative
-        g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
-
-        step_size = backtracking_line_search(
-            objective,
-            g_0,
-            init=init * self.global_state["initial_scale"],
-            beta=0.5 * self.global_state["beta_scale"],
-            c=c1,
-            maxiter=maxiter * 2,
-            a_min=None,
-            try_negative=plus_minus,
-        )
-
-        # found an alpha that reduces loss
-        if step_size is not None:
-            self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
-            return step_size
-
-        # on fail reduce beta scale value
-        self.global_state['beta_scale'] /= 1.5
-        return 0
+        return 0
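
Beyond the rename, the behavioural changes in the strong_wolfe.py hunks are: a new `use_prev` flag that seeds each search with the previously accepted step size (stored in `global_state['prev_alpha']`), and removal of the old backtracking fallback, so a failed search now simply returns a zero step. A sketch of opting into the warm start, reusing the composition pattern from the docstring above:

    import torch
    import torchzero as tz

    model = torch.nn.Linear(4, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.PolakRibiere(),
        tz.m.StrongWolfe(c2=0.1, use_prev=True),  # use_prev is new in 0.3.11
    )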

torchzero/modules/misc/__init__.py (new file)
@@ -0,0 +1,27 @@
+from .debug import PrintLoss, PrintParams, PrintShape, PrintUpdate
+from .escape import EscapeAnnealing
+from .gradient_accumulation import GradientAccumulation
+from .misc import (
+    DivByLoss,
+    FillLoss,
+    GradSign,
+    GraftGradToUpdate,
+    GraftToGrad,
+    GraftToParams,
+    HpuEstimate,
+    LastAbsoluteRatio,
+    LastDifference,
+    LastGradDifference,
+    LastProduct,
+    LastRatio,
+    MulByLoss,
+    NoiseSign,
+    Previous,
+    RandomHvp,
+    Relative,
+    UpdateSign,
+)
+from .multistep import Multistep, NegateOnLossIncrease, Online, Sequential
+from .regularization import Dropout, PerturbWeights, WeightDropout
+from .split import Split
+from .switch import Alternate, Switch
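
This new `__init__.py` is where the helpers moved out of torchzero/modules/ops (plus the new gradient accumulation, multistep, regularization and split modules) are exposed, matching the moves in the file list above. A short sketch of the 0.3.11 import locations, using only names visible in the file shown (whether they are also re-exported at a higher level is not visible in this diff):

    from torchzero.modules.misc import (
        GradientAccumulation,  # new in 0.3.11
        Multistep,
        Split,                 # previously torchzero/modules/ops/split.py
        Switch,                # previously torchzero/modules/ops/switch.py
    )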