torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -6,36 +6,35 @@ from ...core import Module, Target, Transform
 from ...utils.tensorlist import Distributions, TensorList
 
 
-class Clone(Transform):
-    def __init__(self): super().__init__({}, uses_grad=False)
-    @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings): return [t.clone() for t in tensors]
-
-class Grad(Module):
+class Clone(Module):
+    """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [g.clone() for g in var.get_grad()]
+        var.update = [u.clone() for u in var.get_update()]
         return var
 
-class Params(Module):
+class Grad(Module):
+    """Outputs the gradient"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [p.clone() for p in var.params]
+        var.update = [g.clone() for g in var.get_grad()]
         return var
 
-class Update(Module):
+class Params(Module):
+    """Outputs parameters"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
     def step(self, var):
-        var.update = [u.clone() for u in var.get_update()]
+        var.update = [p.clone() for p in var.params]
         return var
 
 class Zeros(Module):
+    """Outputs zeros"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -44,6 +43,7 @@ class Zeros(Module):
         return var
 
 class Ones(Module):
+    """Outputs ones"""
     def __init__(self):
         super().__init__({})
     @torch.no_grad
@@ -52,6 +52,7 @@ class Ones(Module):
         return var
 
 class Fill(Module):
+    """Outputs tensors filled with :code:`value`"""
     def __init__(self, value: float):
         defaults = dict(value=value)
         super().__init__(defaults)
@@ -62,6 +63,7 @@ class Fill(Module):
         return var
 
 class RandomSample(Module):
+    """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
     def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
         defaults = dict(eps=eps, distribution=distribution)
         super().__init__(defaults)
@@ -74,6 +76,7 @@ class RandomSample(Module):
         return var
 
 class Randn(Module):
+    """Outputs tensors filled with random numbers from a normal distribution with mean 0 and variance 1."""
     def __init__(self):
         super().__init__({})
 
@@ -83,6 +86,7 @@ class Randn(Module):
         return var
 
 class Uniform(Module):
+    """Outputs tensors filled with random numbers from uniform distribution between :code:`low` and :code:`high`."""
     def __init__(self, low: float, high: float):
         defaults = dict(low=low, high=high)
         super().__init__(defaults)
@@ -94,19 +98,23 @@ class Uniform(Module):
         return var
 
 class GradToNone(Module):
+    """Sets :code:`grad` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.grad = None
         return var
 
 class UpdateToNone(Module):
+    """Sets :code:`update` attribute to None on :code:`var`."""
     def __init__(self): super().__init__()
     def step(self, var):
         var.update = None
         return var
 
 class Identity(Module):
+    """A placeholder identity operator that is argument-insensitive."""
     def __init__(self, *args, **kwargs): super().__init__()
     def step(self, var): return var
 
-NoOp = Identity
+NoOp = Identity
+"""A placeholder identity operator that is argument-insensitive."""
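The reworked output modules above (Clone, Grad, Params, Zeros, ...) are building blocks for modular optimizer chains. A minimal usage sketch, assuming they are exposed under tz.m the same way the modules in the AdaHessian docstring further down are (tz.Modular and tz.m.LR appear there); model stands for any torch.nn.Module:

    import torchzero as tz

    # hypothetical chain: start from the raw gradient, clone it so later
    # in-place modules cannot mutate the stored tensors, then apply a step size
    opt = tz.Modular(
        model.parameters(),
        tz.m.Grad(),
        tz.m.Clone(),
        tz.m.LR(0.1),
    )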
@@ -1,7 +1,18 @@
 from .adagrad import Adagrad, FullMatrixAdagrad
+
+# from .curveball import CurveBall
+# from .spectral import SpectralPreconditioner
+from .adahessian import AdaHessian
 from .adam import Adam
+from .adan import Adan
+from .adaptive_heavyball import AdaptiveHeavyBall
+from .esgd import ESGD
+from .ladagrad import LMAdagrad
 from .lion import Lion
+from .mars import MARSCorrection
+from .msam import MSAM, MSAMObjective
 from .muon import DualNormCorrection, MuonAdjustLR, Orthogonalize, orthogonalize_grads_
+from .orthograd import OrthoGrad, orthograd_
 from .rmsprop import RMSprop
 from .rprop import (
     BacktrackOnSignChange,
@@ -10,9 +21,7 @@ from .rprop import (
     SignConsistencyLRs,
     SignConsistencyMask,
 )
+from .sam import ASAM, SAM
 from .shampoo import Shampoo
 from .soap import SOAP
-from .orthograd import OrthoGrad, orthograd_
 from .sophia_h import SophiaH
-# from .curveball import CurveBall
-# from .spectral import SpectralPreconditioner
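The net effect of this hunk is a wider re-export surface for the optimizers subpackage. A sketch of imports that should now resolve, based only on the paths shown above:

    # names newly re-exported from torchzero/modules/optimizers/__init__.py in 0.3.11
    from torchzero.modules.optimizers import (
        ASAM,
        AdaHessian,
        Adan,
        AdaptiveHeavyBall,
        ESGD,
        LMAdagrad,
        MARSCorrection,
        MSAM,
        MSAMObjective,
        SAM,
    )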
@@ -25,6 +25,7 @@ def adagrad_(
     step: int,
     pow: float = 2,
     use_sqrt: bool = True,
+    divide: bool = False,
 
     # inner args
     inner: Module | None = None,
@@ -40,6 +41,8 @@
         assert params is not None
         tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
+    if divide: sq_sum_ = sq_sum_ / max(step, 1)
+
     if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
     else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
 
@@ -48,7 +51,9 @@
 
 
 class Adagrad(Transform):
-    """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
+    """Adagrad, divides by sum of past squares of gradients.
+
+    This implementation is identical to :code:`torch.optim.Adagrad`.
 
     Args:
         lr_decay (float, optional): learning rate decay. Defaults to 0.
@@ -67,23 +72,24 @@ class Adagrad(Transform):
         alpha: float = 1,
         pow: float = 2,
         use_sqrt: bool = True,
+        divide: bool = False,
         inner: Chainable | None = None,
     ):
         defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
-                        eps = eps, pow=pow, use_sqrt = use_sqrt)
+                        eps = eps, pow=pow, use_sqrt = use_sqrt, divide=divide)
         super().__init__(defaults=defaults, uses_grad=False)
 
         if inner is not None:
             self.set_child('inner', inner)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = TensorList(tensors)
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         lr_decay,alpha,eps = unpack_dicts(settings, 'lr_decay', 'alpha', 'eps', cls=NumberList)
 
-        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(settings[0])
+        pow, use_sqrt, divide = itemgetter('pow', 'use_sqrt', 'divide')(settings[0])
 
         sq_sum = unpack_states(states, tensors, 'sq_sum', cls=TensorList)
 
@@ -100,6 +106,7 @@ class Adagrad(Transform):
             step=self.global_state["step"],
             pow=pow,
             use_sqrt=use_sqrt,
+            divide=divide,
 
             # inner args
             inner=self.children.get("inner", None),
@@ -110,17 +117,17 @@ class Adagrad(Transform):
 
 
 class FullMatrixAdagrad(TensorwiseTransform):
-    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=False, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', inner: Chainable | None = None):
-        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init)
-        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+    def __init__(self, beta: float | None = None, decay: float | None = None, sqrt:bool=True, concat_params=True, update_freq=1, init: Literal['identity', 'zeros', 'ones', 'GGT'] = 'identity', divide: bool=False, inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay, sqrt=sqrt, init=init, divide=divide)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner,)
 
     @torch.no_grad
-    def update_tensor(self, tensor, param, grad, loss, state, settings):
+    def update_tensor(self, tensor, param, grad, loss, state, setting):
         G = tensor.ravel()
         GG = torch.outer(G, G)
-        decay = settings['decay']
-        beta = settings['beta']
-        init = settings['init']
+        decay = setting['decay']
+        beta = setting['beta']
+        init = setting['init']
 
         if 'GG' not in state:
             if init == 'identity': state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
@@ -132,11 +139,14 @@ class FullMatrixAdagrad(TensorwiseTransform):
 
         if beta is not None: state['GG'].lerp_(GG, 1-beta)
         else: state['GG'].add_(GG)
+        state['i'] = state.get('i', 0) + 1  # number of GGTs in sum
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
         GG = state['GG']
-        sqrt = settings['sqrt']
+        sqrt = setting['sqrt']
+        divide = setting['divide']
+        if divide: GG = GG/state.get('i', 1)
 
         if tensor.numel() == 1:
             GG = GG.squeeze()
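The new divide flag in both Adagrad variants normalizes the accumulator by the number of accumulated steps, so the preconditioner uses a running mean of squared gradients (or outer products) rather than a running sum. A hedged usage sketch, assuming Adagrad is exposed under tz.m like the modules in the docstring examples below:

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adagrad(divide=True),  # divides sq_sum by max(step, 1) before preconditioning
        tz.m.LR(1e-2),
    )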
@@ -0,0 +1,223 @@
+import math
+from collections.abc import Callable
+from typing import Literal
+
+import torch
+
+from ...core import Chainable, Module, Target, Transform, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+
+def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+    """averages x over first dimension in blocks"""
+    if enable and x.ndim >= 2:
+        if math.prod(x.shape[1:]) <= 1: return x
+        size = x.size(0)
+        if block_size is None: return x.mean(0, keepdim=True)
+
+        n_blocks = size // block_size
+        if n_blocks <= 1: return x.mean(0, keepdim = True)
+
+        n_remaining = size - n_blocks * block_size
+        remaining = None
+        if n_remaining > 0:
+            remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            x = x[:-n_remaining]
+
+        x = x.view(block_size, n_blocks, *x.shape[1:])
+        x_mean = x.mean(0).repeat_interleave(block_size, 0)
+
+        if remaining is None: return x_mean
+        return torch.cat([x_mean, remaining], 0)
+
+    return x
+
+def _rademacher_like(tensor, p = 0.5, generator = None):
+    """p is probability of a 1, other values will be -1."""
+    return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
+
+def adahessian(
+    tensors: TensorList,
+    D: TensorList | None,
+    exp_avg_: TensorList,
+    D_exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    update_freq: int,
+    eps: float | NumberList,
+    step: int,
+):
+    # momentum
+    exp_avg_.lerp_(tensors, 1-beta1)
+    num = exp_avg_ / (1-beta1)
+
+    # update preconditioner
+    if step % update_freq == 0:
+        assert D is not None
+        D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+    else:
+        assert D is None
+
+    denom = (D_exp_avg_sq_ / (1-beta2)).sqrt_().add_(eps)
+
+    return num.div_(denom)
+
+
+class AdaHessian(Module):
+    """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+    This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+    .. note::
+        In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+    .. note::
+        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
+
+    Args:
+        beta1 (float, optional): first momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+        averaging (bool, optional):
+            whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+            This can be set per-parameter in param groups.
+        block_size (int, optional):
+            size of block in the block-diagonal averaging.
+        update_freq (int, optional):
+            frequency of updating hessian diagonal estimate via a hessian-vector product.
+            This value can be increased to reduce computational cost. Defaults to 1.
+        eps (float, optional):
+            division stability epsilon. Defaults to 1e-8.
+        hvp_method (str, optional):
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
+        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        n_samples (int, optional):
+            number of hessian-vector products with random vectors to evaluate each time when updating
+            the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+        seed (int | None, optional): seed for random vectors. Defaults to None.
+        inner (Chainable | None, optional):
+            Inner module. If this is specified, operations are performed in the following order.
+            1. compute hessian diagonal estimate.
+            2. pass inputs to :code:`inner`.
+            3. momentum and preconditioning are applied to the outputs of :code:`inner`.
+
+    Examples:
+        Using AdaHessian:
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(),
+                tz.m.LR(0.1)
+            )
+
+        AdaHessian preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+        Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+        AdaHessian preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+        .. code-block:: python

+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+                tz.m.LR(0.1)
+            )
+
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        averaging: bool = False,
+        block_size: int | None = 9,
+        update_freq: int = 1,
+        eps: float = 1e-8,
+        hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+        fd_h: float = 1e-3,
+        n_samples = 1,
+        seed: int | None = None,
+        inner: Chainable | None = None
+    ):
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        super().__init__(defaults)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def step(self, var):
+        params = var.params
+        settings = self.settings[params[0]]
+        hvp_method = settings['hvp_method']
+        fd_h = settings['fd_h']
+        update_freq = settings['update_freq']
+        n_samples = settings['n_samples']
+
+        seed = settings['seed']
+        generator = None
+        if seed is not None:
+            if 'generator' not in self.global_state:
+                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+            generator = self.global_state['generator']
+
+        beta1, beta2, eps, averaging, block_size = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+
+        exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        closure = var.closure
+        assert closure is not None
+
+        D = None
+        if step % update_freq == 0:
+
+            rgrad=None
+            for i in range(n_samples):
+                u = [_rademacher_like(p, generator=generator) for p in params]
+
+                Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                      h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+
+                if D is None: D = Hvp
+                else: torch._foreach_add_(D, Hvp)
+
+            assert D is not None
+            if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+            D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+        update = var.get_update()
+        if 'inner' in self.children:
+            update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+        var.update = adahessian(
+            tensors=TensorList(update),
+            D=TensorList(D) if D is not None else None,
+            exp_avg_=exp_avg,
+            D_exp_avg_sq_=D_exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            update_freq=update_freq,
+            eps=eps,
+            step=step,
+        )
+        return var
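The Rademacher hessian-vector products above amount to a Hutchinson-style diagonal estimate: for a random vector u with entries of plus or minus 1, the expectation of u * (H @ u) is the diagonal of H. A standalone illustration of that identity (illustration only, not torchzero code; a small dense H for clarity):

    import torch

    def hutchinson_diag(H: torch.Tensor, n_samples: int = 1000) -> torch.Tensor:
        # estimate diag(H) by averaging u * (H @ u) over random Rademacher vectors u
        d = torch.zeros(H.shape[0], dtype=H.dtype)
        for _ in range(n_samples):
            u = (torch.randint(0, 2, (H.shape[0],)) * 2 - 1).to(H.dtype)  # entries are +/-1
            d += u * (H @ u)
        return d / n_samples

    H = torch.tensor([[2.0, 0.3], [0.3, -1.0]])
    print(hutchinson_diag(H))  # approaches tensor([ 2., -1.]) as n_samples grows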
@@ -10,7 +10,7 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..lr.lr import lazy_lr
+from ..step_size.lr import lazy_lr
 from ..momentum.experimental import sqrt_nag_ema_sq_
 from ..momentum.momentum import nag_
 
@@ -33,7 +33,7 @@ def adam_(
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
 ):
-    """Returns new tensors or updates params in-place."""
+    """Returns new tensors."""
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                    debiased=False,step=step,pow=pow)
 
@@ -43,11 +43,12 @@
 
     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+    return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
 
 class Adam(Transform):
-    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
-    pytorch in that debiasing is applied after adding epsilon.
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
+
+    This implementation is identical to :code:`torch.optim.Adam`.
 
     Args:
         beta1 (float, optional): momentum. Defaults to 0.9.
@@ -75,7 +76,7 @@ class Adam(Transform):
        if inner is not None: self.set_child('inner', inner)
 
    @torch.no_grad
-   def apply(self, tensors, params, grads, loss, states, settings):
+   def apply_tensors(self, tensors, params, grads, loss, states, settings):
        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
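For reference, debiased_step_size above folds the usual Adam bias correction into the step size; with pow=2 that is alpha * sqrt(1 - beta2**t) / (1 - beta1**t). A sketch of the textbook formula (torchzero's actual helper may differ in details):

    # hypothetical stand-in for torchzero's debiased_step_size helper
    def adam_debiased_alpha(step: int, beta1: float, beta2: float, alpha: float = 1.0, pow: float = 2.0) -> float:
        return alpha * (1 - beta2 ** step) ** (1 / pow) / (1 - beta1 ** step)

    print(adam_debiased_alpha(1, beta1=0.9, beta2=0.999))  # ~0.316 at step 1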
@@ -0,0 +1,110 @@
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+def adan_(
+    g: TensorList,
+    g_prev_: TensorList,
+    m_: TensorList, # exponential moving average
+    v_: TensorList, # exponential moving average of gradient differences
+    n_: TensorList, # kinda like squared momentum
+    n_prev_: TensorList | None,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    beta3: float | NumberList,
+    eps: float | NumberList,
+    use_n_prev: bool,
+):
+    """Returns new tensors."""
+    m_.lerp_(g, 1-beta1)
+
+    y = g - g_prev_
+    v_.lerp_(y, 1-beta2)
+
+    y.mul_(1-beta2).add_(g)
+    n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+
+    if use_n_prev:
+        assert n_prev_ is not None
+        ns = n_prev_.clone()
+        n_prev_.copy_(n_)
+        n_ = ns
+
+    eta = n_.sqrt().add_(eps).reciprocal_()
+    term = m_ + (1-beta2)*v_
+    update = eta.mul_(term)
+
+    g_prev_.copy_(g)
+
+    return update
+
+
+class Adan(Transform):
+    """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.98.
+        beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+        beta3 (float, optional): third (squared) momentum. Defaults to 0.99.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        use_n_prev (bool, optional):
+            whether to use previous gradient differences momentum.
+
+    Reference:
+        Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+    """
+    def __init__(
+        self,
+        beta1: float = 0.98,
+        beta2: float = 0.92,
+        beta3: float = 0.99,
+        eps: float = 1e-8,
+        use_n_prev: bool = False,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,use_n_prev=use_n_prev)
+        super().__init__(defaults, uses_grad=False)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
+        s = settings[0]
+        use_n_prev = s['use_n_prev']
+
+        g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
+
+
+        if use_n_prev:
+            n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
+        else:
+            n_prev = None
+
+        if step == 1:
+            # initial values, also runs on restarts
+            m.copy_(tensors)
+            n.set_(tensors ** 2)
+            v.zero_()
+            g_prev.copy_(tensors)
+            if n_prev is not None: n_prev.set_(tensors ** 2)
+
+        if step == 2:
+            v.set_(tensors - g_prev)
+
+        update = adan_(
+            g=tensors,
+            g_prev_=g_prev,
+            m_=m,
+            v_=v,
+            n_=n,
+            n_prev_=n_prev,
+            beta1=beta1,
+            beta2=beta2,
+            beta3=beta3,
+            eps=eps,
+            use_n_prev=use_n_prev,
+        )
+
+        return update
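A hedged usage sketch for the new Adan module, assuming it is exposed under tz.m and chained with an LR module like the other optimizer modules in this release:

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.Adan(beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8),
        tz.m.LR(1e-3),
    )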
@@ -0,0 +1,57 @@
+import torch
+from ...core import Transform
+from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+    if f - f_star <= torch.finfo(p[0].dtype).eps: return g
+
+    g_g = g.dot(g)
+    g_gp = g.dot(g_prev)
+    num = -(f - f_star) * g.dot(g_prev)
+    denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+    m = num/denom
+
+    h = 2*(f - f_star) / g_g
+    return (1 + m) * h * g - m*(p-p_prev)
+
+
+class AdaptiveHeavyBall(Transform):
+    """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+    This is related to conjugate gradient methods; it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
+
+    .. note::
+        The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+    Args:
+        f_star (int, optional):
+            (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+        tol (float, optional):
+            tolerance on objective value change.
+    """
+    def __init__(self, f_star: float = 0):
+        defaults = dict(f_star=f_star)
+        super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+        setting = settings[0]
+        f_star = setting['f_star']
+
+        f_prev = self.global_state.get('f_prev', None)
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
+
+        if f_prev is None:
+            self.global_state['f_prev'] = loss
+            h = 2*(loss - f_star) / tensors.dot(tensors)
+            return h * tensors
+
+        update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+        self.global_state['f_prev'] = loss
+        p_prev.copy_(params)
+        g_prev.copy_(tensors)
+        return update
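A hedged usage sketch for AdaptiveHeavyBall; per the docstring note, the method picks its own step size, so no learning-rate module is chained (assumes the class is exposed under tz.m like the other new optimizer modules):

    import torchzero as tz

    opt = tz.Modular(
        model.parameters(),
        tz.m.AdaptiveHeavyBall(f_star=0.0),  # f_star: estimated lowest attainable loss
    )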