torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -10,12 +10,60 @@ from ...core import Chainable, apply_transform, Module
  from ...utils.linalg.solve import nystrom_sketch_and_solve, nystrom_pcg

  class NystromSketchAndSolve(Module):
+ """Newton's method with a Nyström sketch-and-solve solver.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. note::
+ In most cases NystromSketchAndSolve should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ .. note::
+ If this is unstable, increase the :code:`reg` parameter and tune the rank.
+
+ .. note::
+ :code:`tz.m.NystromPCG` usually outperforms this.
+
+ Args:
+ rank (int): size of the sketch, this many hessian-vector products will be evaluated per step.
+ reg (float, optional): regularization parameter. Defaults to 1e-3.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+ seed (int | None, optional): seed for random generator. Defaults to None.
+
+ Examples:
+ NystromSketchAndSolve with backtracking line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NystromSketchAndSolve(10),
+ tz.m.Backtracking()
+ )
+
+ Reference:
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized Nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+ """
  def __init__(
  self,
  rank: int,
  reg: float = 1e-3,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
- h=1e-3,
+ h: float = 1e-3,
  inner: Chainable | None = None,
  seed: int | None = None,
  ):
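The three ``hvp_method`` options documented above correspond to standard Hessian-vector-product strategies. A minimal standalone sketch of what they compute (illustrative helper, not this package's internal implementation):

    import torch

    def hvp(f, x, v, method="autograd", h=1e-3):
        # exact HVP via double backward: needs a graph for the gradient
        if method == "autograd":
            x = x.clone().requires_grad_(True)
            (g,) = torch.autograd.grad(f(x), x, create_graph=True)
            (Hv,) = torch.autograd.grad(g, x, grad_outputs=v)
            return Hv

        def grad_at(p):
            p = p.clone().requires_grad_(True)
            return torch.autograd.grad(f(p), p)[0]

        if method == "forward":   # one extra gradient evaluation on top of the gradient at x
            return (grad_at(x + h * v) - grad_at(x)) / h
        if method == "central":   # two extra gradient evaluations, O(h^2) accurate
            return (grad_at(x + h * v) - grad_at(x - h * v)) / (2 * h)
        raise ValueError(method)

    f = lambda x: (x ** 4).sum()       # Hessian is diag(12 * x**2)
    x = torch.tensor([1.0, 2.0])
    v = torch.tensor([1.0, 0.0])
    print(hvp(f, x, v, "central"))     # ≈ tensor([12., 0.])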
@@ -86,6 +134,61 @@ class NystromSketchAndSolve(Module):


  class NystromPCG(Module):
+ """Newton's method with a Nyström-preconditioned conjugate gradient solver.
+ This tends to outperform NewtonCG but requires tuning the sketch size.
+ An adaptive version exists in https://arxiv.org/abs/2110.02820; I might implement it at some point.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients for calculating HVPs.
+ The closure must accept a ``backward`` argument (refer to documentation).
+
+ .. note::
+ In most cases NystromPCG should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+
+ Args:
+ sketch_size (int):
+ size of the sketch for preconditioning, this many hessian-vector products will be evaluated before
+ running the conjugate gradient solver. A larger value improves the preconditioning and speeds up
+ conjugate gradient.
+ maxiter (int | None, optional):
+ maximum number of iterations. By default this is set to the number of dimensions
+ in the objective function, which is supposed to be enough for conjugate gradient
+ to have guaranteed convergence. Setting this to a small value can still generate good enough directions.
+ Defaults to None.
+ tol (float, optional): relative tolerance for conjugate gradient solver. Defaults to 1e-4.
+ reg (float, optional): regularization parameter. Defaults to 1e-8.
+ hvp_method (str, optional):
+ Determines how Hessian-vector products are evaluated.
+
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+ This requires creating a graph for the gradient.
+ - ``"forward"``: Use a forward finite difference formula to
+ approximate the HVP. This requires one extra gradient evaluation.
+ - ``"central"``: Use a central finite difference formula for a
+ more accurate HVP approximation. This requires two extra
+ gradient evaluations.
+ Defaults to "autograd".
+ h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+ inner (Chainable | None, optional): modules to apply hessian preconditioner to. Defaults to None.
+ seed (int | None, optional): seed for random generator. Defaults to None.
+
+ Examples:
+
+ NystromPCG with backtracking line search
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.NystromPCG(10),
+ tz.m.Backtracking()
+ )
+
+ Reference:
+ Frangella, Z., Tropp, J. A., & Udell, M. (2023). Randomized Nyström preconditioning. SIAM Journal on Matrix Analysis and Applications, 44(2), 718-752. https://arxiv.org/abs/2110.02820
+
+ """
  def __init__(
  self,
  sketch_size: int,
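For orientation, both modules build on the randomized Nyström approximation of the (PSD, regularized) Hessian from the cited paper. Given a random test matrix Ω with r columns (r being ``rank`` or ``sketch_size``), each column of AΩ costs one Hessian-vector product, which is why r HVPs are evaluated per step. A sketch of the approximation in standard notation (not code from this package):

    \hat{A} \;=\; (A\Omega)\,\bigl(\Omega^{\top} A \Omega\bigr)^{\dagger}\,(A\Omega)^{\top},
    \qquad A \in \mathbb{R}^{n \times n}\ \text{PSD}, \quad \Omega \in \mathbb{R}^{n \times r}.

As the solver names suggest, NystromSketchAndSolve solves the Newton system against the regularized approximation directly, while NystromPCG uses it as a preconditioner for conjugate gradient on the true system.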
@@ -64,6 +64,40 @@ def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
  m.reset()

  class GaussianHomotopy(Reformulation):
+ """Approximately smooths the function with a Gaussian kernel by sampling it at randomly perturbed points around the current point. Both function values and gradients are averaged over all samples. The perturbed points are generated before each
+ step and remain the same throughout the step.
+
+ .. note::
+ This module reformulates the objective: it modifies the closure to evaluate the value and gradients of a smoothed function. All modules after this one will operate on the modified objective.
+
+ .. note::
+ This module requires a closure passed to the optimizer step,
+ as it needs to re-evaluate the loss and gradients at perturbed points.
+
+ Args:
+ n_samples (int): number of points to sample, larger values lead to more accurate smoothing.
+ init_sigma (float): initial scale of perturbations.
+ tol (float | None, optional):
+ if the maximal parameter change is smaller than this, sigma is reduced by :code:`decay`. Defaults to 1e-4.
+ decay (float, optional): multiplier applied to sigma when converged on the smoothed function. Defaults to 0.5.
+ max_steps (int | None, optional): maximum number of steps before decaying sigma. Defaults to None.
+ clear_state (bool, optional):
+ whether to clear all other module states when sigma is decayed, because the objective function changes. Defaults to True.
+ seed (int | None, optional): seed for random perturbations. Defaults to None.
+
+ Examples:
+ Gaussian-smoothed NewtonCG
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.GaussianHomotopy(100),
+ tz.m.NewtonCG(maxiter=20),
+ tz.m.AdaptiveBacktracking(),
+ )
+
+ """
  def __init__(
  self,
  n_samples: int,
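The smoothing GaussianHomotopy describes amounts to a Monte-Carlo estimate of a Gaussian-blurred objective. A standalone sketch of the idea for a plain tensor-to-scalar function (illustrative only; the module itself works through the closure and keeps the perturbations fixed within a step):

    import torch

    def gaussian_smoothed(f, x, n_samples=100, sigma=1.0):
        # average value and gradient of f over fixed perturbations x + sigma * e_i
        perturbations = [sigma * torch.randn_like(x) for _ in range(n_samples)]
        total_loss, total_grad = 0.0, torch.zeros_like(x)
        for e in perturbations:
            xe = (x + e).detach().requires_grad_(True)
            loss = f(xe)
            total_grad += torch.autograd.grad(loss, xe)[0]
            total_loss += float(loss)
        return total_loss / n_samples, total_grad / n_samples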
@@ -56,7 +56,7 @@ def _precompute_denominator(tensor: torch.Tensor, sigma) -> torch.Tensor:
  return 1 - sigma * torch.fft.fft(v) # pylint: disable = not-callable

  class LaplacianSmoothing(Transform):
- """Applies laplacian smoothing via a fast Fourier transform solver.
+ """Applies Laplacian smoothing via a fast Fourier transform solver, which can improve generalization.

  Args:
  sigma (float, optional): controls the amount of smoothing. Defaults to 1.
@@ -69,9 +69,19 @@ class LaplacianSmoothing(Transform):
  target (str, optional):
  what to set on var.

+ Examples:
+ Laplacian Smoothing Gradient Descent optimizer as in the paper
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.LaplacianSmoothing(),
+ tz.m.LR(1e-2),
+ )
+
  Reference:
- *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
- Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
+ Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022). Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.

  """
  def __init__(self, sigma:float = 1, layerwise=True, min_numel = 4, target: Target = 'update'):
@@ -82,7 +92,7 @@ class LaplacianSmoothing(Transform):


  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  layerwise = settings[0]['layerwise']

  # layerwise laplacian smoothing
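The ``1 - sigma * torch.fft.fft(v)`` denominator above is the circulant solve from the cited paper: the smoothed gradient g_s satisfies (I - sigma * L) g_s = g, with L a periodic 1-D discrete Laplacian. A sketch for a single flattened gradient, assuming the standard [1, -2, 1] stencil (layerwise handling and ``min_numel`` are omitted; not this module's exact code):

    import torch

    def laplacian_smooth(g: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
        # solve (I - sigma * L) g_s = g via FFT, L = periodic [1, -2, 1] Laplacian
        flat = g.flatten()
        v = torch.zeros_like(flat)
        v[0], v[1], v[-1] = -2.0, 1.0, 1.0
        denom = 1 - sigma * torch.fft.fft(v)          # same form as _precompute_denominator
        g_s = torch.fft.ifft(torch.fft.fft(flat) / denom).real
        return g_s.view_as(g)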
@@ -0,0 +1,2 @@
+ from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
+ from .adaptive import PolyakStepSize, BarzilaiBorwein
@@ -0,0 +1,122 @@
+ """Various step size strategies"""
+ from typing import Any, Literal
+ from operator import itemgetter
+ import torch
+
+ from ...core import Transform, Chainable
+ from ...utils import TensorList, unpack_dicts, unpack_states, NumberList
+
+
+ class PolyakStepSize(Transform):
+ """Polyak's subgradient method.
+
+ Args:
+ f_star (int, optional):
+ (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+ max (float | None, optional): maximum possible step size. Defaults to None.
+ use_grad (bool, optional):
+ if True, uses dot product of update and gradient to compute the step size.
+ Otherwise, dot product of update with itself is used, which has no geometric meaning so it probably won't work well.
+ Defaults to False.
+ alpha (float, optional): multiplier to Polyak step-size. Defaults to 1.
+ """
+ def __init__(self, f_star: float = 0, max: float | None = None, use_grad=False, alpha: float = 1, inner: Chainable | None = None):
+
+ defaults = dict(alpha=alpha, max=max, f_star=f_star, use_grad=use_grad)
+ super().__init__(defaults, uses_grad=use_grad, uses_loss=True, inner=inner)
+
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
+ assert grads is not None and loss is not None
+ tensors = TensorList(tensors)
+ grads = TensorList(grads)
+
+ use_grad, max, f_star = itemgetter('use_grad', 'max', 'f_star')(settings[0])
+
+ if use_grad: gg = tensors.dot(grads)
+ else: gg = tensors.dot(tensors)
+
+ if gg.abs() <= torch.finfo(gg.dtype).eps: step_size = 0 # converged
+ else: step_size = (loss - f_star) / gg
+
+ if max is not None:
+ if step_size > max: step_size = max
+
+ self.global_state['step_size'] = step_size
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step_size = self.global_state.get('step_size', 1)
+ torch._foreach_mul_(tensors, step_size * unpack_dicts(settings, 'alpha', cls=NumberList))
+ return tensors
+
+
+
+ def _bb_short(s: TensorList, y: TensorList, sy, eps, fallback):
+ yy = y.dot(y)
+ if yy < eps:
+ if sy < eps: return fallback # try to fallback on long
+ ss = s.dot(s)
+ return ss/sy
+ return sy/yy
+
+ def _bb_long(s: TensorList, y: TensorList, sy, eps, fallback):
+ ss = s.dot(s)
+ if sy < eps:
+ yy = y.dot(y) # try to fallback on short
+ if yy < eps: return fallback
+ return sy/yy
+ return ss/sy
+
+ def _bb_geom(s: TensorList, y: TensorList, sy, eps, fallback):
+ short = _bb_short(s, y, sy, eps, fallback)
+ long = _bb_long(s, y, sy, eps, fallback)
+ return (short * long) ** 0.5
+
+ class BarzilaiBorwein(Transform):
+ """Barzilai-Borwein method.
+
+ Args:
+ type (str, optional):
+ one of "short" with formula sᵀy/yᵀy, "long" with formula sᵀs/sᵀy, or "geom" to use geometric mean of short and long.
+ Defaults to 'geom'.
+ scale_first (bool, optional):
+ whether to make first step very small when previous gradient is not available. Defaults to True.
+ fallback (float, optional): step size when denominator is less than 0 (will happen on negative curvature). Defaults to 1e-3.
+ inner (Chainable | None, optional):
+ step size will be applied to outputs of this module. Defaults to None.
+
+ """
+ def __init__(self, type: Literal['long', 'short', 'geom'] = 'geom', scale_first:bool=True, fallback:float=1e-3, inner:Chainable|None = None):
+ defaults = dict(type=type, fallback=fallback)
+ super().__init__(defaults, uses_grad=False, scale_first=scale_first, inner=inner)
+
+ def reset_for_online(self):
+ super().reset_for_online()
+ self.clear_state_keys('prev_p', 'prev_g')
+
+ @torch.no_grad
+ def update_tensors(self, tensors, params, grads, loss, states, settings):
+ prev_p, prev_g = unpack_states(states, tensors, 'prev_p', 'prev_g', cls=TensorList)
+ fallback = unpack_dicts(settings, 'fallback', cls=NumberList)
+ type = settings[0]['type']
+
+ s = params-prev_p
+ y = tensors-prev_g
+ sy = s.dot(y)
+ eps = torch.finfo(sy.dtype).eps
+
+ if type == 'short': step_size = _bb_short(s, y, sy, eps, fallback)
+ elif type == 'long': step_size = _bb_long(s, y, sy, eps, fallback)
+ elif type == 'geom': step_size = _bb_geom(s, y, sy, eps, fallback)
+ else: raise ValueError(type)
+
+ self.global_state['step_size'] = step_size
+
+ prev_p.copy_(params)
+ prev_g.copy_(tensors)
+
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step_size = self.global_state.get('step_size', 1)
+ torch._foreach_mul_(tensors, step_size)
+ return tensors
+
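Unlike most other docstrings in this release, the two new step-size transforms above ship without usage examples. A minimal sketch in the same style (the surrounding chain is an assumption; only the module names and arguments come from the code above):

    # Polyak step size: eta = (f(x) - f_star) / <g, g>, here applied to the raw gradient
    opt = tz.Modular(
        model.parameters(),
        tz.m.PolyakStepSize(f_star=0, max=1.0),
    )

    # Barzilai-Borwein with the geometric-mean rule sqrt((s'y / y'y) * (s's / s'y))
    opt = tz.Modular(
        model.parameters(),
        tz.m.BarzilaiBorwein(type="geom"),
    )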
@@ -0,0 +1,154 @@
+ """Learning rate"""
+ import torch
+ import random
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, generic_ne, unpack_dicts
+
+ def lazy_lr(tensors: TensorList, lr: float | list, inplace:bool):
+ """multiplies by lr if lr is not 1"""
+ if generic_ne(lr, 1):
+ if inplace: return tensors.mul_(lr)
+ return tensors * lr
+ return tensors
+
+ class LR(Transform):
+ """Learning rate. Adding this module also adds support for LR schedulers."""
+ def __init__(self, lr: float):
+ defaults=dict(lr=lr)
+ super().__init__(defaults, uses_grad=False)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ return lazy_lr(TensorList(tensors), lr=[s['lr'] for s in settings], inplace=True)
+
+ class StepSize(Transform):
+ """This is exactly the same as LR, except the `lr` parameter can be renamed to any other name to avoid clashes."""
+ def __init__(self, step_size: float, key = 'step_size'):
+ defaults={"key": key, key: step_size}
+ super().__init__(defaults, uses_grad=False)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ return lazy_lr(TensorList(tensors), lr=[s[s['key']] for s in settings], inplace=True)
+
+
+ def _warmup_lr(step: int, start_lr: float | NumberList, end_lr: float | NumberList, steps: float):
+ """returns warm up lr scalar"""
+ if step > steps: return end_lr
+ return start_lr + (end_lr - start_lr) * (step / steps)
+
+ class Warmup(Transform):
+ """Learning rate warmup: linearly increases the learning rate multiplier from :code:`start_lr` to :code:`end_lr` over :code:`steps` steps.
+
+ Args:
+ steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+ start_lr (float, optional): initial learning rate multiplier on the first step. Defaults to 1e-5.
+ end_lr (float, optional): learning rate multiplier at the end of and after warmup. Defaults to 1.
+
+ Example:
+ Adam with 1000 steps warmup
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.LR(1e-2),
+ tz.m.Warmup(steps=1000)
+ )
+
+ """
+ def __init__(self, steps = 100, start_lr = 1e-5, end_lr:float = 1):
+ defaults = dict(start_lr=start_lr,end_lr=end_lr, steps=steps)
+ super().__init__(defaults, uses_grad=False)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ start_lr, end_lr = unpack_dicts(settings, 'start_lr', 'end_lr', cls = NumberList)
+ num_steps = settings[0]['steps']
+ step = self.global_state.get('step', 0)
+
+ tensors = lazy_lr(
+ TensorList(tensors),
+ lr=_warmup_lr(step=step, start_lr=start_lr, end_lr=end_lr, steps=num_steps),
+ inplace=True
+ )
+ self.global_state['step'] = step + 1
+ return tensors
+
+ class WarmupNormClip(Transform):
+ """Warmup via clipping of the update norm.
+
+ Args:
+ start_norm (float, optional): maximal norm on the first step. Defaults to 1e-5.
+ end_norm (float, optional): maximal norm on the last step. After that, norm clipping is disabled. Defaults to 1.
+ steps (int, optional): number of steps to perform warmup for. Defaults to 100.
+
+ Example:
+ Adam with 1000 steps norm clip warmup
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.WarmupNormClip(steps=1000),
+ tz.m.LR(1e-2),
+ )
+ """
+ def __init__(self, steps = 100, start_norm = 1e-5, end_norm:float = 1):
+ defaults = dict(start_norm=start_norm,end_norm=end_norm, steps=steps)
+ super().__init__(defaults, uses_grad=False)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ start_norm, end_norm = unpack_dicts(settings, 'start_norm', 'end_norm', cls = NumberList)
+ num_steps = settings[0]['steps']
+ step = self.global_state.get('step', 0)
+ if step > num_steps: return tensors
+
+ tensors = TensorList(tensors)
+ norm = tensors.global_vector_norm()
+ current_max_norm = _warmup_lr(step, start_norm[0], end_norm[0], num_steps)
+ if norm > current_max_norm:
+ tensors.mul_(current_max_norm / norm)
+
+ self.global_state['step'] = step + 1
+ return tensors
+
+
+ class RandomStepSize(Transform):
+ """Uses a random global or layer-wise step size from `low` to `high`.
+
+ Args:
+ low (float, optional): minimum learning rate. Defaults to 0.
+ high (float, optional): maximum learning rate. Defaults to 1.
+ parameterwise (bool, optional):
+ if True, generate a random step size for each parameter separately,
+ if False generate one global random step size. Defaults to False.
+ """
+ def __init__(self, low: float = 0, high: float = 1, parameterwise=False, seed:int|None=None):
+ defaults = dict(low=low, high=high, parameterwise=parameterwise,seed=seed)
+ super().__init__(defaults, uses_grad=False)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ s = settings[0]
+ parameterwise = s['parameterwise']
+
+ seed = s['seed']
+ if 'generator' not in self.global_state:
+ self.global_state['generator'] = random.Random(seed)
+ generator: random.Random = self.global_state['generator']
+
+ if parameterwise:
+ low, high = unpack_dicts(settings, 'low', 'high')
+ lr = [generator.uniform(l, h) for l, h in zip(low, high)]
+ else:
+ low = s['low']
+ high = s['high']
+ lr = generator.uniform(low, high)
+
+ torch._foreach_mul_(tensors, lr)
+ return tensors
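The `key` argument of the new `StepSize` transform is easiest to see with a small example; a hedged sketch (the key name is made up, only `StepSize` and `LR` come from the code above):

    opt = tz.Modular(
        model.parameters(),
        tz.m.StepSize(0.5, key="trust_radius"),  # stored under "trust_radius", so it does not clash with "lr"
        tz.m.LR(1e-2),
    )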
@@ -1 +1 @@
- from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, NormalizedWeightDecay
+ from .weight_decay import WeightDecay, DirectWeightDecay, decay_weights_, RelativeWeightDecay
@@ -22,22 +22,99 @@ def weight_decay_(


  class WeightDecay(Transform):
+ """Weight decay.
+
+ Args:
+ weight_decay (float): weight decay scale.
+ ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+ target (Target, optional): what to set on var. Defaults to 'update'.
+
+ Examples:
+ Adam with non-decoupled weight decay
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.WeightDecay(1e-3),
+ tz.m.Adam(),
+ tz.m.LR(1e-3)
+ )
+
+ Adam with decoupled weight decay that still scales with the learning rate
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.WeightDecay(1e-3),
+ tz.m.LR(1e-3)
+ )
+
+ Adam with fully decoupled weight decay that doesn't scale with the learning rate
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.LR(1e-3),
+ tz.m.WeightDecay(1e-6)
+ )
+
+ """
  def __init__(self, weight_decay: float, ord: int = 2, target: Target = 'update'):
+
  defaults = dict(weight_decay=weight_decay, ord=ord)
  super().__init__(defaults, uses_grad=False, target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  weight_decay = NumberList(s['weight_decay'] for s in settings)
  ord = settings[0]['ord']

  return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay, ord)

- class NormalizedWeightDecay(Transform):
+ class RelativeWeightDecay(Transform):
+ """Weight decay relative to the mean absolute value of the update, gradient or parameters, depending on the value of the :code:`norm_input` argument.
+
+ Args:
+ weight_decay (float): relative weight decay scale.
+ ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+ norm_input (str, optional):
+ determines what the weight decay is relative to: "update", "grad" or "params".
+ Defaults to "update".
+ target (Target, optional): what to set on var. Defaults to 'update'.
+
+ Examples:
+ Adam with non-decoupled relative weight decay
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.RelativeWeightDecay(1e-3),
+ tz.m.Adam(),
+ tz.m.LR(1e-3)
+ )
+
+ Adam with decoupled relative weight decay
+
+ .. code-block:: python
+
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Adam(),
+ tz.m.RelativeWeightDecay(1e-3),
+ tz.m.LR(1e-3)
+ )
+
+ """
  def __init__(
  self,
  weight_decay: float = 0.1,
- ord: int = 2,
+ ord: int = 2,
  norm_input: Literal["update", "grad", "params"] = "update",
  target: Target = "update",
  ):
@@ -45,7 +122,7 @@ class NormalizedWeightDecay(Transform):
  super().__init__(defaults, uses_grad=norm_input == 'grad', target=target)

  @torch.no_grad
- def apply(self, tensors, params, grads, loss, states, settings):
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
  weight_decay = NumberList(s['weight_decay'] for s in settings)

  ord = settings[0]['ord']
@@ -60,9 +137,9 @@ class NormalizedWeightDecay(Transform):
  else:
  raise ValueError(norm_input)

- norm = src.global_vector_norm(ord)
+ mean_abs = src.abs().global_mean()

- return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * norm, ord)
+ return weight_decay_(as_tensorlist(tensors), as_tensorlist(params), weight_decay * mean_abs, ord)


  @torch.no_grad
@@ -72,7 +149,12 @@ def decay_weights_(params: Iterable[torch.Tensor], weight_decay: float | NumberL
  weight_decay_(params, params, -weight_decay, ord)

  class DirectWeightDecay(Module):
- """directly decays weights in-place"""
+ """Directly applies weight decay to parameters.
+
+ Args:
+ weight_decay (float): weight decay scale.
+ ord (int, optional): order of the penalty, e.g. 1 for L1 and 2 for L2. Defaults to 2.
+ """
  def __init__(self, weight_decay: float, ord: int = 2,):
  defaults = dict(weight_decay=weight_decay, ord=ord)
  super().__init__(defaults)
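A compact way to read the three WeightDecay placements documented above, assuming the module simply adds λθ (for ord=2) to whatever tensor flows through it and η is the LR value (an interpretation of the docstring, not code from the package):

    \text{before Adam (coupled):}\qquad \theta_{t+1} = \theta_t - \eta\,\mathrm{Adam}(g_t + \lambda\theta_t)
    \text{after Adam, before LR:}\qquad \theta_{t+1} = \theta_t - \eta\,\bigl(\mathrm{Adam}(g_t) + \lambda\theta_t\bigr)
    \text{after LR (fully decoupled):}\qquad \theta_{t+1} = \theta_t - \eta\,\mathrm{Adam}(g_t) - \lambda\theta_t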
@@ -7,7 +7,35 @@ from ...utils import Params, _copy_param_groups, _make_param_groups


  class Wrap(Module):
- """Custom param groups are supported only by `set_param_groups`. Settings passed to Modular will be ignored."""
+ """
+ Wraps a PyTorch optimizer so it can be used as a module.
+
+ .. note::
+ Custom param groups are supported only by `set_param_groups`; settings passed to Modular will be ignored.
+
+ Args:
+ opt_fn (Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer):
+ function that takes in parameters and returns the optimizer, for example :code:`torch.optim.Adam`
+ or :code:`lambda parameters: torch.optim.Adam(parameters, lr=1e-3)`
+ *args:
+ **kwargs:
+ Extra args to be passed to opt_fn. The function is called as :code:`opt_fn(parameters, *args, **kwargs)`.
+
+ Example:
+ wrapping pytorch_optimizer.StableAdamW
+
+ .. code-block:: py
+
+ from pytorch_optimizer import StableAdamW
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.Wrap(StableAdamW, lr=1),
+ tz.m.Cautious(),
+ tz.m.LR(1e-2)
+ )
+
+ """
  def __init__(self, opt_fn: Callable[..., torch.optim.Optimizer] | torch.optim.Optimizer, *args, **kwargs):
  super().__init__()
  self._opt_fn = opt_fn
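The same pattern works with a stock PyTorch optimizer; a minimal sketch (only torch.optim.SGD's own arguments are used, and the chain mirrors the StableAdamW example above):

    opt = tz.Modular(
        model.parameters(),
        tz.m.Wrap(torch.optim.SGD, lr=1.0, momentum=0.9),
        tz.m.LR(1e-2),
    )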