torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,7 @@ from typing import Any
  import numpy as np
  import torch

- from ...core import Module, Target, Vars
+ from ...core import Module, Target, Var
  from ...utils import tofloat


@@ -62,12 +62,12 @@ class LineSearch(Module, ABC):
  if any(a!=0 for a in alpha):
  torch._foreach_add_(params, torch._foreach_mul(update, alpha))

- def _loss(self, step_size: float, vars: Vars, closure, params: list[torch.Tensor],
+ def _loss(self, step_size: float, var: Var, closure, params: list[torch.Tensor],
  update: list[torch.Tensor], backward:bool=False) -> float:

  # if step_size is 0, we might already know the loss
- if (vars.loss is not None) and (step_size == 0):
- return tofloat(vars.loss)
+ if (var.loss is not None) and (step_size == 0):
+ return tofloat(var.loss)

  # check max iter
  if self._maxiter is not None and self._current_iter >= self._maxiter: raise MaxLineSearchItersReached
@@ -85,23 +85,23 @@ class LineSearch(Module, ABC):
  self._lowest_loss = tofloat(loss)
  self._best_step_size = step_size

- # if evaluated loss at step size 0, set it to vars.loss
+ # if evaluated loss at step size 0, set it to var.loss
  if step_size == 0:
- vars.loss = loss
- if backward: vars.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+ var.loss = loss
+ if backward: var.grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]

  return tofloat(loss)

- def _loss_derivative(self, step_size: float, vars: Vars, closure,
+ def _loss_derivative(self, step_size: float, var: Var, closure,
  params: list[torch.Tensor], update: list[torch.Tensor]):
  # if step_size is 0, we might already know the derivative
- if (vars.grad is not None) and (step_size == 0):
- loss = self._loss(step_size=step_size,vars=vars,closure=closure,params=params,update=update,backward=False)
- derivative = - sum(t.sum() for t in torch._foreach_mul(vars.grad, update))
+ if (var.grad is not None) and (step_size == 0):
+ loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=False)
+ derivative = - sum(t.sum() for t in torch._foreach_mul(var.grad, update))

  else:
  # loss with a backward pass sets params.grad
- loss = self._loss(step_size=step_size,vars=vars,closure=closure,params=params,update=update,backward=True)
+ loss = self._loss(step_size=step_size,var=var,closure=closure,params=params,update=update,backward=True)

  # directional derivative
  derivative = - sum(t.sum() for t in torch._foreach_mul([p.grad if p.grad is not None
@@ -109,60 +109,60 @@ class LineSearch(Module, ABC):

  return loss, tofloat(derivative)

- def evaluate_step_size(self, step_size: float, vars: Vars, backward:bool=False):
- closure = vars.closure
+ def evaluate_step_size(self, step_size: float, var: Var, backward:bool=False):
+ closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss(step_size=step_size, vars=vars, closure=closure, params=vars.params,update=vars.get_update(),backward=backward)
+ return self._loss(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update(),backward=backward)

- def evaluate_step_size_loss_and_derivative(self, step_size: float, vars: Vars):
- closure = vars.closure
+ def evaluate_step_size_loss_and_derivative(self, step_size: float, var: Var):
+ closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return self._loss_derivative(step_size=step_size, vars=vars, closure=closure, params=vars.params,update=vars.get_update())
+ return self._loss_derivative(step_size=step_size, var=var, closure=closure, params=var.params,update=var.get_update())

- def make_objective(self, vars: Vars, backward:bool=False):
- closure = vars.closure
+ def make_objective(self, var: Var, backward:bool=False):
+ closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss, vars=vars, closure=closure, params=vars.params, update=vars.get_update(), backward=backward)
+ return partial(self._loss, var=var, closure=closure, params=var.params, update=var.get_update(), backward=backward)

- def make_objective_with_derivative(self, vars: Vars):
- closure = vars.closure
+ def make_objective_with_derivative(self, var: Var):
+ closure = var.closure
  if closure is None: raise RuntimeError('line search requires closure')
- return partial(self._loss_derivative, vars=vars, closure=closure, params=vars.params, update=vars.get_update())
+ return partial(self._loss_derivative, var=var, closure=closure, params=var.params, update=var.get_update())

  @abstractmethod
- def search(self, update: list[torch.Tensor], vars: Vars) -> float:
+ def search(self, update: list[torch.Tensor], var: Var) -> float:
  """Finds the step size to use"""

  @torch.no_grad
- def step(self, vars: Vars) -> Vars:
+ def step(self, var: Var) -> Var:
  self._reset()
- params = vars.params
- update = vars.get_update()
+ params = var.params
+ update = var.get_update()

  try:
- step_size = self.search(update=update, vars=vars)
+ step_size = self.search(update=update, var=var)
  except MaxLineSearchItersReached:
  step_size = self._best_step_size

  # set loss_approx
- if vars.loss_approx is None: vars.loss_approx = self._lowest_loss
+ if var.loss_approx is None: var.loss_approx = self._lowest_loss

  # this is last module - set step size to found step_size times lr
- if vars.is_last:
+ if var.is_last:

- if vars.last_module_lrs is None:
+ if var.last_module_lrs is None:
  self.set_step_size_(step_size, params=params, update=update)

  else:
- self._set_per_parameter_step_size_([step_size*lr for lr in vars.last_module_lrs], params=params, update=update)
+ self._set_per_parameter_step_size_([step_size*lr for lr in var.last_module_lrs], params=params, update=update)

- vars.stop = True; vars.skip_update = True
- return vars
+ var.stop = True; var.skip_update = True
+ return var

  # revert parameters and multiply update by step size
  self.set_step_size_(0, params=params, update=update)
- torch._foreach_mul_(vars.update, step_size)
- return vars
+ torch._foreach_mul_(var.update, step_size)
+ return var


  class GridLineSearch(LineSearch):
@@ -172,10 +172,10 @@ class GridLineSearch(LineSearch):
  super().__init__(defaults)

  @torch.no_grad
- def search(self, update, vars):
- start,end,num=itemgetter('start','end','num')(self.settings[vars.params[0]])
+ def search(self, update, var):
+ start,end,num=itemgetter('start','end','num')(self.settings[var.params[0]])

  for lr in torch.linspace(start,end,num):
- self.evaluate_step_size(lr.item(), vars=vars, backward=False)
+ self.evaluate_step_size(lr.item(), var=var, backward=False)

  return self._best_step_size
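The hunks above also show the contract a line search implements after the vars→var rename: a subclass overrides search(self, update, var), probes candidate step sizes through evaluate_step_size or make_objective, and the base class records the lowest loss and best step size seen. A rough illustration of that pattern follows; this is a hypothetical subclass written for this note, not a module shipped in the wheel, and the import path is assumed from the file listing.

    import torch
    from torchzero.modules.line_search.line_search import LineSearch  # assumed path

    class HalvingLineSearch(LineSearch):
        """Hypothetical example: try a geometrically shrinking set of step sizes
        and return the best one the base class recorded."""
        def __init__(self, init: float = 1.0, num: int = 10):
            super().__init__(dict(init=init, num=num))

        @torch.no_grad
        def search(self, update, var):
            settings = self.settings[var.params[0]]
            step_size = settings['init']
            for _ in range(settings['num']):
                # the base class tracks self._lowest_loss / self._best_step_size
                self.evaluate_step_size(step_size, var=var, backward=False)
                step_size *= 0.5
            return self._best_step_size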
@@ -7,6 +7,21 @@ from .line_search import LineSearch


  class ScipyMinimizeScalar(LineSearch):
+ """Line search via :code:`scipy.optimize.minimize_scalar` which implements brent, golden search and bounded brent methods.
+
+ Args:
+ method (str | None, optional): "brent", "golden" or "bounded". Defaults to None.
+ maxiter (int | None, optional): maximum number of function evaluations the line search is allowed to perform. Defaults to None.
+ bracket (Sequence | None, optional):
+ Either a triple (xa, xb, xc) satisfying xa < xb < xc and func(xb) < func(xa) and func(xb) < func(xc), or a pair (xa, xb) to be used as initial points for a downhill bracket search. Defaults to None.
+ bounds (Sequence | None, optional):
+ For method ‘bounded’, bounds is mandatory and must have two finite items corresponding to the optimization bounds. Defaults to None.
+ tol (float | None, optional): Tolerance for termination. Defaults to None.
+ options (dict | None, optional): A dictionary of solver options. Defaults to None.
+
+ For more details on methods and arguments refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize_scalar.html
+
+ """
  def __init__(
  self,
  method: str | None = None,
@@ -24,10 +39,10 @@ class ScipyMinimizeScalar(LineSearch):


  @torch.no_grad
- def search(self, update, vars):
- objective = self.make_objective(vars=vars)
+ def search(self, update, var):
+ objective = self.make_objective(var=var)
  method, bracket, bounds, tol, options, maxiter = itemgetter(
- 'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[vars.params[0]])
+ 'method', 'bracket', 'bounds', 'tol', 'options', 'maxiter')(self.settings[var.params[0]])

  if maxiter is not None:
  options = dict(options) if isinstance(options, Mapping) else {}
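For reference, the scipy routine this wrapper delegates to minimizes a one-dimensional function of the step size. The snippet below is a standalone sketch with a placeholder quadratic standing in for the objective that make_objective would build; it is not torchzero code.

    from scipy.optimize import minimize_scalar

    def objective(alpha: float) -> float:
        # placeholder for loss(params + alpha * update)
        return (alpha - 0.3) ** 2

    res_brent = minimize_scalar(objective, method="brent", bracket=(0.0, 1.0), tol=1e-6)
    res_bounded = minimize_scalar(objective, method="bounded", bounds=(0.0, 2.0),
                                  options={"maxiter": 50})
    print(res_brent.x, res_bounded.x)  # both close to 0.3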
@@ -183,6 +183,21 @@ def _notfinite(x):
  return not math.isfinite(x)

  class StrongWolfe(LineSearch):
+ """Cubic interpolation line search satisfying Strong Wolfe condition.
+
+ Args:
+ init (float, optional): Initial step size. Defaults to 1.0.
+ c1 (float, optional): Acceptance value for weak wolfe condition. Defaults to 1e-4.
+ c2 (float, optional): Acceptance value for strong wolfe condition (set to 0.1 for conjugate gradient). Defaults to 0.9.
+ maxiter (int, optional): Maximum number of line search iterations. Defaults to 25.
+ maxzoom (int, optional): Maximum number of zoom iterations. Defaults to 10.
+ expand (float, optional): Expansion factor (multipler to step size when weak condition not satisfied). Defaults to 2.0.
+ adaptive (bool, optional):
+ when enabled, if line search failed, initial step size is reduced.
+ Otherwise it is reset to initial value. Defaults to True.
+ plus_minus (bool, optional):
+ If enabled and the direction is not descent direction, performs line search in opposite direction. Defaults to False.
+ """
  def __init__(
  self,
  init: float = 1.0,
@@ -193,23 +208,22 @@
  # a_max: float = 1e10,
  expand: float = 2.0,
  adaptive = True,
- fallback = False,
  plus_minus = False,
  ):
  defaults=dict(init=init,c1=c1,c2=c2,maxiter=maxiter,maxzoom=maxzoom,
- expand=expand, adaptive=adaptive, fallback=fallback, plus_minus=plus_minus)
+ expand=expand, adaptive=adaptive, plus_minus=plus_minus)
  super().__init__(defaults=defaults)

  self.global_state['initial_scale'] = 1.0
  self.global_state['beta_scale'] = 1.0

  @torch.no_grad
- def search(self, update, vars):
- objective = self.make_objective_with_derivative(vars=vars)
+ def search(self, update, var):
+ objective = self.make_objective_with_derivative(var=var)

- init, c1, c2, maxiter, maxzoom, expand, adaptive, fallback, plus_minus = itemgetter(
+ init, c1, c2, maxiter, maxzoom, expand, adaptive, plus_minus = itemgetter(
  'init', 'c1', 'c2', 'maxiter', 'maxzoom',
- 'expand', 'adaptive', 'fallback', 'plus_minus')(self.settings[vars.params[0]])
+ 'expand', 'adaptive', 'plus_minus')(self.settings[var.params[0]])

  f_0, g_0 = objective(0)

@@ -232,29 +246,4 @@

  # fallback to backtracking on fail
  if adaptive: self.global_state['initial_scale'] *= 0.5
- if not fallback: return 0
-
- objective = self.make_objective(vars=vars)
-
- # # directional derivative
- g_0 = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
-
- step_size = backtracking_line_search(
- objective,
- g_0,
- init=init * self.global_state["initial_scale"],
- beta=0.5 * self.global_state["beta_scale"],
- c=c1,
- maxiter=maxiter * 2,
- a_min=None,
- try_negative=plus_minus,
- )
-
- # found an alpha that reduces loss
- if step_size is not None:
- self.global_state['beta_scale'] = min(1.0, self.global_state.get('beta_scale', 1) * math.sqrt(1.5))
- return step_size
-
- # on fail reduce beta scale value
- self.global_state['beta_scale'] /= 1.5
- return 0
+ return 0
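The c1/c2 arguments documented in the new StrongWolfe docstring are the standard Wolfe constants. The following standalone sketch spells out the two acceptance tests they control (illustrative only; f_a and g_a denote the loss and directional derivative at step size a):

    def strong_wolfe_accepts(f_0, g_0, f_a, g_a, a, c1=1e-4, c2=0.9):
        # sufficient decrease (weak Wolfe / Armijo) condition
        armijo = f_a <= f_0 + c1 * a * g_0
        # strong curvature condition; a smaller c2 (e.g. 0.1 for conjugate gradient) is stricter
        curvature = abs(g_a) <= c2 * abs(g_0)
        return armijo and curvature

    print(strong_wolfe_accepts(f_0=1.0, g_0=-2.0, f_a=0.5, g_a=-0.1, a=0.5))  # True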
@@ -6,31 +6,43 @@ from .line_search import LineSearch


  class TrustRegion(LineSearch):
- """Basic first order trust region, re-evaluates closure with updated parameters and scales step size based on function value change"""
+ """Basic first order trust region method. Re-evaluates the function after stepping, if value decreased sufficiently,
+ step size is increased. If value increased, step size is decreased. This is prone to collapsing.
+
+ Args:
+ nplus (float, optional): multiplier to step size on successful steps. Defaults to 1.5.
+ nminus (float, optional): multiplier to step size on unsuccessful steps. Defaults to 0.75.
+ c (float, optional): descent condition. Defaults to 1e-4.
+ init (float, optional): initial step size. Defaults to 1.
+ backtrack (bool, optional): whether to undo the step if value increased. Defaults to True.
+ adaptive (bool, optional):
+ If enabled, when multiple consecutive steps have been successful or unsuccessful,
+ the corresponding multipliers are increased, otherwise they are reset. Defaults to True.
+ """
  def __init__(self, nplus: float=1.5, nminus: float=0.75, c: float=1e-4, init: float = 1, backtrack: bool = True, adaptive: bool = True):
  defaults = dict(nplus=nplus, nminus=nminus, c=c, init=init, backtrack=backtrack, adaptive=adaptive)
  super().__init__(defaults)

  @torch.no_grad
- def search(self, update, vars):
+ def search(self, update, var):

- nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[vars.params[0]])
+ nplus, nminus, c, init, backtrack, adaptive = itemgetter('nplus','nminus','c','init','backtrack', 'adaptive')(self.settings[var.params[0]])
  step_size = self.global_state.setdefault('step_size', init)
  previous_success = self.global_state.setdefault('previous_success', False)
  nplus_mul = self.global_state.setdefault('nplus_mul', 1)
  nminus_mul = self.global_state.setdefault('nminus_mul', 1)


- f_0 = self.evaluate_step_size(0, vars, backward=False)
+ f_0 = self.evaluate_step_size(0, var, backward=False)

  # directional derivative (0 if c = 0 because it is not needed)
  if c == 0: d = 0
- else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+ else: d = -sum(t.sum() for t in torch._foreach_mul(var.get_grad(), update))

  # test step size
  sufficient_f = f_0 + c * step_size * min(d, 0) # pyright:ignore[reportArgumentType]

- f_1 = self.evaluate_step_size(step_size, vars, backward=False)
+ f_1 = self.evaluate_step_size(step_size, var, backward=False)

  proposed = step_size

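The nplus/nminus/c arguments in the TrustRegion docstring above control a simple accept-expand, reject-shrink rule. A minimal standalone sketch of that rule (a hypothetical helper written for this note, not the module's code):

    def update_step_size(step_size, f_0, f_1, d, nplus=1.5, nminus=0.75, c=1e-4):
        # required decrease; d is the directional derivative along the update
        sufficient = f_0 + c * step_size * min(d, 0.0)
        if f_1 <= sufficient:
            return step_size * nplus   # successful step: expand
        return step_size * nminus      # unsuccessful step: shrink

    print(update_step_size(1.0, f_0=1.0, f_1=0.8, d=-2.0))  # 1.5
    print(update_step_size(1.0, f_0=1.0, f_1=1.2, d=-2.0))  # 0.75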
@@ -1,2 +1,2 @@
  from .lr import LR, StepSize, Warmup
- from .step_size import PolyakStepSize, RandomStepSize
+ from .adaptive import PolyakStepSize, RandomStepSize
@@ -1,18 +1,20 @@
+ """Various step size strategies"""
  import random
  from typing import Any
-
+ from operator import itemgetter
  import torch

  from ...core import Transform
- from ...utils import TensorList, NumberList
+ from ...utils import TensorList, NumberList, unpack_dicts


  class PolyakStepSize(Transform):
- """Polyak step-size.
+ """Polyak's step-size method.

  Args:
  max (float | None, optional): maximum possible step size. Defaults to None.
- min_obj_value (int, optional): (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+ min_obj_value (int, optional):
+ (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
  use_grad (bool, optional):
  if True, uses dot product of update and gradient to compute the step size.
  Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
@@ -28,29 +30,24 @@ class PolyakStepSize(Transform):
  super().__init__(defaults, uses_grad=use_grad)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- loss = vars.get_loss(False)
+ def apply(self, tensors, params, grads, loss, states, settings):
  assert grads is not None
  tensors = TensorList(tensors)
  grads = TensorList(grads)
- alpha = self.get_settings('alpha', params=params, cls=NumberList)
- settings = self.settings[params[0]]
- parameterwise = settings['parameterwise']
- use_grad = settings['use_grad']
- max = settings['max']
- min_obj_value = settings['min_obj_value']
+ alpha = NumberList(s['alpha'] for s in settings)
+
+ parameterwise, use_grad, max, min_obj_value = itemgetter('parameterwise', 'use_grad', 'max', 'min_obj_value')(settings[0])
+
+ if use_grad: denom = tensors.dot(grads)
+ else: denom = tensors.dot(tensors)

  if parameterwise:
- if use_grad: denom = (tensors * grads).sum()
- else: denom = tensors.pow(2).sum()
  polyak_step_size: TensorList | Any = (loss - min_obj_value) / denom.where(denom!=0, 1)
  polyak_step_size = polyak_step_size.where(denom != 0, 0)
  if max is not None: polyak_step_size = polyak_step_size.clamp_max(max)

  else:
- if use_grad: denom = tensors.dot(grads)
- else: denom = tensors.dot(tensors)
- if denom == 0: polyak_step_size = 0 # we converged
+ if denom.abs() <= torch.finfo(denom.dtype).eps: polyak_step_size = 0 # converged
  else: polyak_step_size = (loss - min_obj_value) / denom

  if max is not None:
@@ -60,9 +57,8 @@ class PolyakStepSize(Transform):
  return tensors


-
  class RandomStepSize(Transform):
- """Uses random global step size from `low` to `high`.
+ """Uses random global or layer-wise step size from `low` to `high`.

  Args:
  low (float, optional): minimum learning rate. Defaults to 0.
@@ -76,21 +72,21 @@ class RandomStepSize(Transform):
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- settings = self.settings[params[0]]
- parameterwise = settings['parameterwise']
+ def apply(self, tensors, params, grads, loss, states, settings):
+ s = settings[0]
+ parameterwise = s['parameterwise']

- seed = settings['seed']
+ seed = s['seed']
  if 'generator' not in self.global_state:
  self.global_state['generator'] = random.Random(seed)
  generator: random.Random = self.global_state['generator']

  if parameterwise:
- low, high = self.get_settings('low', 'high', params=params)
+ low, high = unpack_dicts(settings, 'low', 'high')
  lr = [generator.uniform(l, h) for l, h in zip(low, high)]
  else:
- low = settings['low']
- high = settings['high']
+ low = s['low']
+ high = s['high']
  lr = generator.uniform(low, high)

  torch._foreach_mul_(tensors, lr)
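The quantity PolyakStepSize computes is the classic Polyak rule alpha = (f(x) - f_min) / <update, grad> (or <update, update> when use_grad=False), with a near-zero denominator treated as convergence as in the new guard above. A standalone sketch of the scalar case, not the module itself:

    import torch

    def polyak_step_size(loss, update, grad, min_obj_value=0.0, max_size=None):
        # denominator is the dot product of the update with the gradient
        denom = sum((u * g).sum() for u, g in zip(update, grad))
        if denom.abs() <= torch.finfo(denom.dtype).eps:
            return 0.0  # converged, mirroring the guard added in this version
        alpha = float((loss - min_obj_value) / denom)
        if max_size is not None:
            alpha = min(alpha, max_size)
        return alpha

    g = [torch.tensor([2.0, -4.0])]
    print(polyak_step_size(loss=5.0, update=g, grad=g))  # 5 / 20 = 0.25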
@@ -1,8 +1,8 @@
+ """Learning rate"""
  import torch

  from ...core import Transform
- from ...utils import NumberList, TensorList, generic_eq
-
+ from ...utils import NumberList, TensorList, generic_eq, unpack_dicts

  def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
  """multiplies by lr if lr is not 1"""
@@ -11,48 +11,52 @@ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
  return tensors * lr

  class LR(Transform):
+ """Learning rate. Adding this module also adds support for LR schedulers."""
  def __init__(self, lr: float):
  defaults=dict(lr=lr)
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- return lazy_lr(TensorList(tensors), lr=self.get_settings('lr', params=params), inplace=True)
+ def apply(self, tensors, params, grads, loss, states, settings):
+ return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)

  class StepSize(Transform):
- """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name"""
+ """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
  def __init__(self, step_size: float, key = 'step_size'):
  defaults={"key": key, key: step_size}
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- lrs = []
- for p in params:
- settings = self.settings[p]
- lrs.append(settings[settings['key']])
- return lazy_lr(TensorList(tensors), lr=lrs, inplace=True)
+ def apply(self, tensors, params, grads, loss, states, settings):
+ return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)


- def warmup(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
  """returns warm up lr scalar"""
  if step > steps: return end_lr
  return start_lr + (end_lr - start_lr) * (step / steps)

  class Warmup(Transform):
+ """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+
+ Args:
+ start_lr (_type_, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
+ end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
+ steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+ """
  def __init__(self, start_lr = 1e-5, end_lr:float = 1, steps = 100):
  defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
  super().__init__(defaults, uses_grad=False)

  @torch.no_grad
- def transform(self, tensors, params, grads, vars):
- start_lr, end_lr = self.get_settings('start_lr', 'end_lr', params=params, cls = NumberList)
- num_steps = self.settings[params[0]]['steps']
+ def apply(self, tensors, params, grads, loss, states, settings):
+ start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
+ num_steps = settings[0]['steps']
  step = self.global_state.get('step', 0)

  target = lazy_lr(
  TensorList(tensors),
- lr=warmup(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+ lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
  inplace=True
  )
  self.global_state['step'] = step + 1
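The warmup multiplier Warmup applies is the linear ramp implemented by _warmup_lr in the hunk above; restated as a standalone function for clarity:

    def warmup_multiplier(step: int, start_lr: float = 1e-5, end_lr: float = 1.0, steps: int = 100) -> float:
        # linear ramp from start_lr to end_lr, then constant at end_lr
        if step > steps:
            return end_lr
        return start_lr + (end_lr - start_lr) * (step / steps)

    print(warmup_multiplier(0))    # 1e-05
    print(warmup_multiplier(50))   # ~0.500005
    print(warmup_multiplier(200))  # 1.0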
@@ -1,3 +1,4 @@
+ """Modules that perform averaging over a history of past updates."""
  from collections import deque
  from collections.abc import Sequence
  from typing import Any, Literal, cast
@@ -9,14 +10,19 @@ from ...utils import tolist


  class Averaging(TensorwiseTransform):
+ """Average of past :code:`history_size` updates.
+
+ Args:
+ history_size (int): Number of past updates to average
+ target (Target, optional): target. Defaults to 'update'.
+ """
  def __init__(self, history_size: int, target: Target = 'update'):
  defaults = dict(history_size=history_size)
  super().__init__(uses_grad=False, defaults=defaults, target=target)

  @torch.no_grad
- def transform(self, tensor, param, grad, vars):
- history_size = self.settings[param]['history_size']
- state = self.state[param]
+ def apply_tensor(self, tensor, param, grad, loss, state, settings):
+ history_size = settings['history_size']
  if 'history' not in state:
  state['history'] = deque(maxlen=history_size)
  state['average'] = torch.zeros_like(tensor)
@@ -29,15 +35,19 @@ class Averaging(TensorwiseTransform):
  return average / len(history)

  class WeightedAveraging(TensorwiseTransform):
- """weights are oldest to newest"""
+ """Weighted average of past :code:`len(weights)` updates.
+
+ Args:
+ weights (Sequence[float]): a sequence of weights from oldest to newest.
+ target (Target, optional): target. Defaults to 'update'.
+ """
  def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
  defaults = dict(weights = tolist(weights))
  super().__init__(uses_grad=False, defaults=defaults, target=target)

  @torch.no_grad
- def transform(self, tensor, param, grad, vars):
- weights = self.settings[param]['weights']
- state = self.state[param]
+ def apply_tensor(self, tensor, param, grad, loss, state, settings):
+ weights = settings['weights']

  if 'history' not in state:
  state['history'] = deque(maxlen=len(weights))
@@ -59,14 +69,19 @@ class WeightedAveraging(TensorwiseTransform):


  class MedianAveraging(TensorwiseTransform):
+ """Median of past :code:`history_size` updates.
+
+ Args:
+ history_size (int): Number of past updates to average
+ target (Target, optional): target. Defaults to 'update'.
+ """
  def __init__(self, history_size: int, target: Target = 'update'):
  defaults = dict(history_size = history_size)
  super().__init__(uses_grad=False, defaults=defaults, target=target)

  @torch.no_grad
- def transform(self, tensor, param, grad, vars):
- history_size = self.settings[param]['history_size']
- state = self.state[param]
+ def apply_tensor(self, tensor, param, grad, loss, state, settings):
+ history_size = settings['history_size']

  if 'history' not in state:
  state['history'] = deque(maxlen=history_size)
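All three averaging modules keep a bounded per-tensor history (a deque of the most recent updates) and reduce it with a mean, weighted mean or median. A minimal standalone sketch of that pattern, not the package's exact reduction code:

    from collections import deque
    import torch

    history = deque(maxlen=3)  # like state['history'] with history_size=3
    for t in [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([6.0]), torch.tensor([10.0])]:
        history.append(t)      # the oldest entry is dropped automatically

    stacked = torch.stack(tuple(history))  # only the 3 most recent updates remain
    print(stacked.mean(0))                 # tensor([6.])
    print(stacked.median(0).values)        # tensor([6.])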