torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/smoothing/gaussian.py

@@ -6,7 +6,7 @@ from typing import Literal
 
 import torch
 
-from ...core import Modular, Module, Vars
+from ...core import Modular, Module, Var
 from ...utils import NumberList, TensorList
 from ...utils.derivatives import jacobian_wrt
 from ..grad_approximation import GradApproximator, GradTarget
@@ -17,24 +17,24 @@ class Reformulation(Module, ABC):
         super().__init__(defaults)
 
     @abstractmethod
-    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], vars: Vars) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
+    def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
         """returns loss and gradient, if backward is False then gradient can be None"""
 
-    def pre_step(self, vars: Vars) -> Vars | None:
+    def pre_step(self, var: Var) -> Var | None:
         """This runs once before each step, whereas `closure` may run multiple times per step if further modules
         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return vars
+        return var
 
-    def step(self, vars):
-        ret = self.pre_step(vars)
-        if isinstance(ret, Vars): vars = ret
+    def step(self, var):
+        ret = self.pre_step(var)
+        if isinstance(ret, Var): var = ret
 
-        if vars.closure is None: raise RuntimeError("Reformulation requires closure")
-        params, closure = vars.params, vars.closure
+        if var.closure is None: raise RuntimeError("Reformulation requires closure")
+        params, closure = var.params, var.closure
 
 
         def modified_closure(backward=True):
-            loss, grad = self.closure(backward, closure, params, vars)
+            loss, grad = self.closure(backward, closure, params, var)
 
             if grad is not None:
                 for p,g in zip(params, grad):
@@ -42,8 +42,8 @@ class Reformulation(Module, ABC):
 
             return loss
 
-        vars.closure = modified_closure
-        return vars
+        var.closure = modified_closure
+        return var
 
 
 def _decay_sigma_(self: Module, params):
@@ -58,12 +58,46 @@ def _generate_perturbations_to_state_(self: Module, params: TensorList, n_sample
     for param, prt in zip(params, zip(*perturbations)):
         self.state[param]['perturbations'] = prt
 
-def _clear_state_hook(optimizer: Modular, vars: Vars, self: Module):
+def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
     for m in optimizer.unrolled_modules:
         if m is not self:
             m.reset()
 
 class GaussianHomotopy(Reformulation):
+    """Approximately smoothes the function with a gaussian kernel by sampling it at random perturbed points around the current point. Both function values and gradients are averaged over all samples. The perturbed points are generated before each
+    step and remain the same throughout the step.
+
+    .. note::
+        This module reformulates the objective: it modifies the closure to evaluate the value and gradients of a smoothed function. All modules after this will operate on the modified objective.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients at perturbed points.
+
+    Args:
+        n_samples (int): number of points to sample, larger values lead to a more accurate smoothing.
+        init_sigma (float): initial scale of perturbations.
+        tol (float | None, optional):
+            if the maximal parameter change is smaller than this, sigma is reduced by :code:`decay`. Defaults to 1e-4.
+        decay (float, optional): multiplier to sigma when converged on a smoothed function. Defaults to 0.5.
+        max_steps (int | None, optional): maximum number of steps before decaying sigma. Defaults to None.
+        clear_state (bool, optional):
+            whether to clear all other module states when sigma is decayed, because the objective function changes. Defaults to True.
+        seed (int | None, optional): seed for random perturbations. Defaults to None.
+
+    Examples:
+        Gaussian-smoothed NewtonCG
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.GaussianHomotopy(100),
+                tz.m.NewtonCG(maxiter=20),
+                tz.m.AdaptiveBacktracking(),
+            )
+
+    """
     def __init__(
         self,
         n_samples: int,
@@ -85,12 +119,12 @@ class GaussianHomotopy(Reformulation):
         else: self.global_state['generator'] = None
         return self.global_state['generator']
 
-    def pre_step(self, vars):
-        params = TensorList(vars.params)
+    def pre_step(self, var):
+        params = TensorList(var.params)
         settings = self.settings[params[0]]
         n_samples = settings['n_samples']
-        init_sigma = self.get_settings('init_sigma', params=params)
-        sigmas = self.get_state('sigma', params = params, init=init_sigma)
+        init_sigma = [self.settings[p]['init_sigma'] for p in params]
+        sigmas = self.get_state(params, 'sigma', init=init_sigma)
 
         if any('perturbations' not in self.state[p] for p in params):
             generator = self._get_generator(settings['seed'], params)
@@ -109,9 +143,9 @@ class GaussianHomotopy(Reformulation):
         tol = settings['tol']
         if tol is not None and not decayed:
             if not any('prev_params' in self.state[p] for p in params):
-                prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
             else:
-                prev_params = self.get_state('prev_params', params=params, cls=TensorList, init='param')
+                prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
                 s = params - prev_params
 
                 if s.abs().global_max() <= tol:
@@ -124,10 +158,10 @@ class GaussianHomotopy(Reformulation):
             generator = self._get_generator(settings['seed'], params)
             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
             if settings['clear_state']:
-                vars.post_step_hooks.append(partial(_clear_state_hook, self=self))
+                var.post_step_hooks.append(partial(_clear_state_hook, self=self))
 
     @torch.no_grad
-    def closure(self, backward, closure, params, vars):
+    def closure(self, backward, closure, params, var):
         params = TensorList(params)
 
         settings = self.settings[params[0]]
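The hunks above are mostly the `Vars` to `Var` rename, but they also show the contract a `Reformulation` subclass has to satisfy: `closure(backward, closure, params, var)` returns a `(loss, grad)` pair, with `grad` allowed to be `None` when `backward` is False, and the base `step` wraps the original closure with it. A minimal sketch of a custom reformulation under the new API; the `ScaledObjective` name, the import path, and the assumption that calling `closure()` runs the backward pass and populates `p.grad` are all illustrative, not taken from the diff:

    # Hypothetical example (not part of the diff): a Reformulation that scales
    # the objective by a constant factor using the renamed Var-based API.
    import torch
    from torchzero.modules.smoothing.gaussian import Reformulation  # assumed location

    class ScaledObjective(Reformulation):
        def __init__(self, scale: float = 0.5):
            super().__init__(dict(scale=scale))  # Reformulation.__init__ takes a defaults dict

        def closure(self, backward, closure, params, var):
            scale = self.settings[params[0]]['scale']
            if backward:
                loss = closure()  # assumed convention: backward() is called and p.grad is populated
                grad = [scale * (p.grad if p.grad is not None else torch.zeros_like(p))
                        for p in params]
            else:
                loss = closure(False)  # loss only, no gradient needed
                grad = None
            return loss * scale, grad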
torchzero/modules/smoothing/laplacian.py

@@ -56,7 +56,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
     return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable
 
 class LaplacianSmoothing(Transform):
-    """Applies laplacian smoothing via a fast Fourier transform solver.
+    """Applies laplacian smoothing via a fast Fourier transform solver which can improve generalization.
 
     Args:
         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
@@ -67,11 +67,21 @@ class LaplacianSmoothing(Transform):
             minimum number of elements in a parameter to apply laplacian smoothing to.
             Only has effect if `layerwise` is True. Defaults to 4.
         target (str, optional):
-            what to set on vars.
+            what to set on var.
+
+    Examples:
+        Laplacian Smoothing Gradient Descent optimizer as in the paper
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LaplacianSmoothing(),
+                tz.m.LR(1e-2),
+            )
 
     Reference:
-        *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
-        Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+        Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.
 
 
     def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
@@ -82,19 +92,17 @@ class LaplacianSmoothing(Transform):
 
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        layerwise = self.settings[params[0]]['layerwise']
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        layerwise = settings[0]['layerwise']
 
         # layerwise laplacian smoothing
         if layerwise:
 
             # precompute the denominator for each layer and store it in each parameters state
             smoothed_target = TensorList()
-            for p, t in zip(params, tensors):
-                settings = self.settings[p]
-                if p.numel() > settings['min_numel']:
-                    state = self.state[p]
-                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, settings['sigma'])
+            for p, t, state, setting in zip(params, tensors, states, settings):
+                if p.numel() > setting['min_numel']:
+                    if 'denominator' not in state: state['denominator'] = _precompute_denominator(p, setting['sigma'])
                     smoothed_target.append(torch.fft.ifft(torch.fft.fft(t.view(-1)) / state['denominator']).real.view_as(t)) #pylint:disable=not-callable
                 else:
                     smoothed_target.append(t)
@@ -106,7 +114,7 @@ class LaplacianSmoothing(Transform):
         # precompute full denominator
         tensors = TensorList(tensors)
         if self.global_state.get('full_denominator', None) is None:
-            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), self.settings[params[0]]['sigma'])
+            self.global_state['full_denominator'] = _precompute_denominator(tensors.to_vec(), settings[0]['sigma'])
 
         # apply the smoothing
         vec = tensors.to_vec()
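For reference, the smoothing applied above solves a periodic system (I - sigma*L)u = g in the Fourier domain, which is what `_precompute_denominator` plus the `fft`/`ifft` division implement. Below is a standalone sketch of that operation on a flat vector; the Laplacian stencil `v` is an assumption based on the referenced paper, since the body of `_precompute_denominator` is not shown in this hunk:

    # Hypothetical standalone version of the smoothing step (not taken from the diff).
    import torch

    def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        v = torch.zeros_like(g)
        v[0], v[1], v[-1] = -2.0, 1.0, 1.0           # periodic discrete Laplacian stencil (assumed)
        denominator = 1 - sigma * torch.fft.fft(v)   # same form as _precompute_denominator
        return torch.fft.ifft(torch.fft.fft(g) / denominator).real

    g = torch.randn(1000)
    print(laplacian_smooth(g).norm() / g.norm())     # the smoothed vector has a smaller norm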
torchzero/modules/step_size/__init__.py (new file)

@@ -0,0 +1,2 @@
+from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
+from .adaptive import PolyakStepSize, BarzilaiBorwein
torchzero/modules/step_size/adaptive.py (new file)

@@ -0,0 +1,122 @@
+"""Various step size strategies"""
+from typing import Any, Literal
+from operator import itemgetter
+import torch
+
+from ...core import Transform, Chainable
+from ...utils import TensorList, unpack_dicts, unpack_states, NumberList
+
+
+class PolyakStepSize(Transform):
+    """Polyak's subgradient method.
+
+    Args:
+        f_star (int, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        max (float | None, optional): maximum possible step size. Defaults to None.
+        use_grad (bool, optional):
+            if True, uses dot product of update and gradient to compute the step size.
+            Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
+            Defaults to False.
+        alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
+    """
+    def __init__(self, f_star: float = 0, max: float | None = None, use_grad=False, alpha: float = 1, inner: Chainable | None = None):
+
+        defaults = dict(alpha=alpha, max=max, f_star=f_star, use_grad=use_grad)
+        super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
+
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        assert grads is not None and loss is not None
+        tensors = TensorList(tensors)
+        grads = TensorList(grads)
+
+        use_grad, max, f_star = itemgetter('use_grad', 'max', 'f_star')(settings[0])
+
+        if use_grad: gg = tensors.dot(grads)
+        else: gg = tensors.dot(tensors)
+
+        if gg.abs() <= torch.finfo(gg.dtype).eps: step_size = 0 # converged
+        else: step_size = (loss - f_star) / gg
+
+        if max is not None:
+            if step_size > max: step_size = max
+
+        self.global_state['step_size'] = step_size
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step_size = self.global_state.get('step_size', 1)
+        torch._foreach_mul_(tensors, step_size * unpack_dicts(settings, 'alpha', cls=NumberList))
+        return tensors
+
+
+
+def _bb_short(s: TensorList, y: TensorList, sy, eps, fallback):
+    yy = y.dot(y)
+    if yy < eps:
+        if sy < eps: return fallback # try to fallback on long
+        ss = s.dot(s)
+        return ss/sy
+    return sy/yy
+
+def _bb_long(s: TensorList, y: TensorList, sy, eps, fallback):
+    ss = s.dot(s)
+    if sy < eps:
+        yy = y.dot(y) # try to fallback on short
+        if yy < eps: return fallback
+        return sy/yy
+    return ss/sy
+
+def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback):
+    short = _bb_short(s, y, sy, eps, fallback)
+    long = _bb_long(s, y, sy, eps, fallback)
+    return (short * long) ** 0.5
+
+class BarzilaiBorwein(Transform):
+    """Barzilai-Borwein method.
+
+    Args:
+        type (str, optional):
+            one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
+            Defaults to 'geom'.
+        scale_first (bool, optional):
+            whether to make first step very small when previous gradient is not available. Defaults to True.
+        fallback (float, optional): step size when denominator is less than 0 (will happen on negative curvature). Defaults to 1e-3.
+        inner (Chainable | None, optional):
+            step size will be applied to outputs of this module. Defaults to None.
+
+    """
+    def __init__(self, type: Literal['long', 'short', 'geom'] = 'geom', scale_first:bool=True, fallback:float=1e-3, inner:Chainable|None = None):
+        defaults = dict(type=type, fallback=fallback)
+        super().__init__(defaults, uses_grad=False, scale_first=scale_first, inner=inner)
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_p', 'prev_g')
+
+    @torch.no_grad
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
+        fallback = unpack_dicts(settings, 'fallback', cls=NumberList)
+        type = settings[0]['type']
+
+        s = params-prev_p
+        y = tensors-prev_g
+        sy = s.dot(y)
+        eps = torch.finfo(sy.dtype).eps
+
+        if type == 'short': step_size = _bb_short(s, y, sy, eps, fallback)
+        elif type == 'long': step_size = _bb_long(s, y, sy, eps, fallback)
+        elif type == 'geom': step_size = _bb_geom(s, y, sy, eps, fallback)
+        else: raise ValueError(type)
+
+        self.global_state['step_size'] = step_size
+
+        prev_p.copy_(params)
+        prev_g.copy_(tensors)
+
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        step_size = self.global_state.get('step_size', 1)
+        torch._foreach_mul_(tensors, step_size)
+        return tensors
+
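In `update_tensors` above, `s = params - prev_p` and `y = tensors - prev_g` are the parameter and gradient differences between consecutive steps, so "long" sᵀs/sᵀy and "short" sᵀy/yᵀy are the classical Barzilai-Borwein step sizes. A hypothetical usage sketch, assuming these classes are re-exported under `tz.m` like the other modules shown in this diff and that `tz.Modular` supports the usual optimizer loop:

    # Hypothetical example (not taken from the diff): Barzilai-Borwein step size on the gradient.
    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    X, y = torch.randn(64, 10), torch.randn(64, 1)

    opt = tz.Modular(
        model.parameters(),
        tz.m.BarzilaiBorwein(type='geom'),  # assumed tz.m export
        tz.m.LR(1e-1),                      # safeguard multiplier on top of the BB step
    )

    for _ in range(100):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(X), y)
        loss.backward()
        opt.step()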
torchzero/modules/step_size/lr.py (new file)

@@ -0,0 +1,154 @@
+"""Learning rate"""
+import torch
+import random
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
+
+def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
+    """multiplies by lr if lr is not 1"""
+    if generic_ne(lr, 1):
+        if inplace: return tensors.mul_(lr)
+        return tensors * lr
+    return tensors
+
+class LR(Transform):
+    """Learning rate. Adding this module also adds support for LR schedulers."""
+    def __init__(self, lr: float):
+        defaults=dict(lr=lr)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
+
+class StepSize(Transform):
+    """this is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes"""
+    def __init__(self, step_size: float, key = 'step_size'):
+        defaults={"key": key, key: step_size}
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
+
+
+def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+    """returns warm up lr scalar"""
+    if step > steps: return end_lr
+    return start_lr + (end_lr - start_lr) * (step / steps)
+
+class Warmup(Transform):
+    """Learning rate warmup, linearly increases learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+
+    Args:
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+        start_lr (float, optional): initial learning rate multiplier on first step. Defaults to 1e-5.
+        end_lr (float, optional): learning rate multiplier at the end and after warmup. Defaults to 1.
+
+    Example:
+        Adam with 1000 steps warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.LR(1e-2),
+                tz.m.Warmup(steps=1000)
+            )
+
+    """
+    def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
+        defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+
+        tensors = lazy_lr(
+            TensorList(tensors),
+            lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+            inplace=True
+        )
+        self.global_state['step'] = step + 1
+        return tensors
+
+class WarmupNormClip(Transform):
+    """Warmup via clipping of the update norm.
+
+    Args:
+        start_norm (float, optional): maximal norm on the first step. Defaults to 1e-5.
+        end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
+        steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+
+    Example:
+        Adam with 1000 steps norm clip warmup
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.WarmupNormClip(steps=1000),
+                tz.m.LR(1e-2),
+            )
+    """
+    def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
+        defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
+        num_steps = settings[0]['steps']
+        step = self.global_state.get('step', 0)
+        if step > num_steps: return tensors
+
+        tensors = TensorList(tensors)
+        norm = tensors.global_vector_norm()
+        current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
+        if norm > current_max_norm:
+            tensors.mul_(current_max_norm / norm)
+
+        self.global_state['step'] = step + 1
+        return tensors
+
+
+class RandomStepSize(Transform):
+    """Uses random global or layer-wise step size from `low` to `high`.
+
+    Args:
+        low (float, optional): minimum learning rate. Defaults to 0.
+        high (float, optional): maximum learning rate. Defaults to 1.
+        parameterwise (bool, optional):
+            if True, generate random step size for each parameter separately,
+            if False generate one global random step size. Defaults to False.
+    """
+    def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
+        defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        s = settings[0]
+        parameterwise = s['parameterwise']
+
+        seed = s['seed']
+        if 'generator' not in self.global_state:
+            self.global_state['generator'] = random.Random(seed)
+        generator: random.Random = self.global_state['generator']
+
+        if parameterwise:
+            low, high = unpack_dicts(settings, 'low', 'high')
+            lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+        else:
+            low = s['low']
+            high = s['high']
+            lr = generator.uniform(low, high)
+
+        torch._foreach_mul_(tensors, lr)
+        return tensors
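`StepSize` above stores its multiplier under a configurable settings key instead of `lr`, so it can sit in the same chain as `LR` (and an LR scheduler) without the two hyperparameters clashing. A hypothetical sketch of that use; the `tz.m` exports and the `update_scale` key name are assumptions:

    # Hypothetical example (not taken from the diff): a fixed 0.5x multiplier stored
    # under 'update_scale', kept separate from the schedulable 'lr' of tz.m.LR.
    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.StepSize(0.5, key='update_scale'),
        tz.m.LR(1e-3),
    )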
torchzero/modules/weight_decay/__init__.py

@@ -1 +1 @@
-from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_
+from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
torchzero/modules/weight_decay/weight_decay.py

@@ -1,9 +1,11 @@
 from collections.abc import Iterable, Sequence
+from typing import Literal
 
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+
 
 @torch.no_grad
 def weight_decay_(
@@ -20,17 +22,126 @@ def weight_decay_(
 
 
 class WeightDecay(Transform):
+    """Weight decay.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    Examples:
+        Adam with non-decoupled weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.WeightDecay(1e-3),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with decoupled weight decay that still scales with learning rate
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.WeightDecay(1e-3),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with fully decoupled weight decay that doesn't scale with learning rate
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.LR(1e-3),
+                tz.m.WeightDecay(1e-6)
+            )
+
+    """
     def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+
         defaults = dict(weight_decay=weight_decay, ord=ord)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        weight_decay = self.get_settings('weight_decay', params=params, cls=NumberList)
-        ord = self.settings[params[0]]['ord']
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+        ord = settings[0]['ord']
 
         return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)
 
+class RelativeWeightDecay(Transform):
+    """Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the :code:`norm_input` argument.
+
+    Args:
+        weight_decay (float): relative weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+        norm_input (str, optional):
+            determines what weight decay should be relative to. "update", "grad" or "params".
+            Defaults to "update".
+        target (Target, optional): what to set on var. Defaults to 'update'.
+
+    Examples:
+        Adam with non-decoupled relative weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.RelativeWeightDecay(1e-3),
+                tz.m.Adam(),
+                tz.m.LR(1e-3)
+            )
+
+        Adam with decoupled relative weight decay
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Adam(),
+                tz.m.RelativeWeightDecay(1e-3),
+                tz.m.LR(1e-3)
+            )
+
+    """
+    def __init__(
+        self,
+        weight_decay: float = 0.1,
+        ord: int = 2,
+        norm_input: Literal["update", "grad", "params"] = "update",
+        target: Target = "update",
+    ):
+        defaults = dict(weight_decay=weight_decay, ord=ord, norm_input=norm_input)
+        super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        weight_decay = NumberList(s['weight_decay'] for s in settings)
+
+        ord = settings[0]['ord']
+        norm_input = settings[0]['norm_input']
+
+        if norm_input == 'update': src = TensorList(tensors)
+        elif norm_input == 'grad':
+            assert grads is not None
+            src = TensorList(grads)
+        elif norm_input == 'params':
+            src = TensorList(params)
+        else:
+            raise ValueError(norm_input)
+
+        mean_abs = src.abs().global_mean()
+
+        return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * mean_abs, ord)
+
+
 @torch.no_grad
 def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberList, ord:int=2):
     """directly decays weights in-place"""
@@ -38,15 +149,20 @@ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberL
     weight_decay_(params, params, -weight_decay, ord)
 
 class DirectWeightDecay(Module):
-    """directly decays weights in-place"""
+    """Directly applies weight decay to parameters.
+
+    Args:
+        weight_decay (float): weight decay scale.
+        ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+    """
     def __init__(self, weight_decay: float, ord: int = 2,):
         defaults = dict(weight_decay=weight_decay, ord=ord)
         super().__init__(defaults)
 
     @torch.no_grad
-    def step(self, vars):
-        weight_decay = self.get_settings('weight_decay', params=vars.params, cls=NumberList)
-        ord = self.settings[vars.params[0]]['ord']
+    def step(self, var):
+        weight_decay = self.get_settings(var.params, 'weight_decay', cls=NumberList)
+        ord = self.settings[var.params[0]]['ord']
 
-        decay_weights_(vars.params, weight_decay, ord)
-        return vars
+        decay_weights_(var.params, weight_decay, ord)
+        return var
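Note the contrast with `WeightDecay`: `DirectWeightDecay.step` above calls `decay_weights_` on `var.params`, shrinking the weights in place regardless of what the rest of the chain computes. A hypothetical usage sketch, assuming the class is exported under `tz.m` like the other modules in this diff:

    # Hypothetical example (not taken from the diff): in-place decay applied every step,
    # independent of the Adam/LR update that the chain produces.
    import torch
    import torchzero as tz

    model = torch.nn.Linear(10, 1)
    opt = tz.Modular(
        model.parameters(),
        tz.m.Adam(),
        tz.m.LR(1e-3),
        tz.m.DirectWeightDecay(1e-4),
    )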