torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -1,38 +1,42 @@
 import math
-from collections.abc import Callable
 from typing import Literal
 
 import torch
 
 from ...core import Chainable, Module, Target, Transform, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
-from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+from ..functional import debiased_step_size
 
+def _full_average(hvp: torch.Tensor):
+    if hvp.ndim >= 3: # Conv kernel
+        return torch.mean(hvp.abs(), dim=[2, *range(3,hvp.ndim)], keepdim=True)
+    return hvp
 
 def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
     """averages x over first dimension in blocks"""
     if enable and x.ndim >= 2:
         if math.prod(x.shape[1:]) <= 1: return x
+        if block_size is None: return _full_average(x)
         size = x.size(0)
-        if block_size is None: return x.mean(0, keepdim=True)
 
         n_blocks = size // block_size
-        if n_blocks <= 1: return x.mean(0, keepdim = True)
+        if n_blocks <= 1: return x.abs().mean(0, keepdim = True)
 
         n_remaining = size - n_blocks * block_size
         remaining = None
         if n_remaining > 0:
-            remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+            remaining = x[-n_remaining:].abs().mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
             x = x[:-n_remaining]
 
         x = x.view(block_size, n_blocks, *x.shape[1:])
-        x_mean = x.mean(0).repeat_interleave(block_size, 0)
+        x_mean = x.abs().mean(0).repeat_interleave(block_size, 0)
 
         if remaining is None: return x_mean
         return torch.cat([x_mean, remaining], 0)
 
     return x
 
+
 def _rademacher_like(tensor, p = 0.5, generator = None):
     """p is probability of a 1, other values will be -1."""
     return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
@@ -46,11 +50,11 @@ def adahessian(
     beta2: float | NumberList,
     update_freq: int,
     eps: float | NumberList,
+    hessian_power: float | NumberList,
     step: int,
 ):
     # momentum
     exp_avg_.lerp_(tensors, 1-beta1)
-    num = exp_avg_ / (1-beta1)
 
     # update preconditioner
     if step % update_freq == 0:
@@ -60,7 +64,9 @@ def adahessian(
     else:
         assert D is None
 
-    denom = (D_exp_avg_sq_ / (1-beta2)).sqrt_().add_(eps)
+
+    denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
+    num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)
 
     return num.div_(denom)
 
@@ -70,16 +76,12 @@ class AdaHessian(Module):
 
     This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
 
-    .. note::
-        In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply AdaHessian preconditioning to another module's output.
+    Notes:
+        - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.
 
-    .. note::
-        If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+        - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".
 
-    .. note::
-        This module requires a closure passed to the optimizer step,
-        as it needs to re-evaluate the loss and gradients for calculating HVPs.
-        The closure must accept a ``backward`` argument (refer to documentation).
+        - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
         beta1 (float, optional): first momentum. Defaults to 0.9.
@@ -105,7 +107,7 @@ class AdaHessian(Module):
             more accurate HVP approximation. This requires two extra
             gradient evaluations.
             Defaults to "autograd".
-        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -113,48 +115,49 @@ class AdaHessian(Module):
         inner (Chainable | None, optional):
             Inner module. If this is specified, operations are performed in the following order.
             1. compute hessian diagonal estimate.
-            2. pass inputs to :code:`inner`.
-            3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
-
-    Examples:
-        Using AdaHessian:
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.AdaHessian(),
-                tz.m.LR(0.1)
-            )
-
-        AdaHessian preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
-        Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
-        AdaHessian preconditioning to nesterov momentum (:code:`tz.m.NAG`):
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
-                tz.m.LR(0.1)
-            )
+            2. pass inputs to ``inner``.
+            3. momentum and preconditioning are applied to the ouputs of ``inner``.
+
+    ## Examples:
+
+    Using AdaHessian:
+
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    AdaHessian preconditioner can be applied to any other module by passing it to the ``inner`` argument.
+    Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+    AdaHessian preconditioning to nesterov momentum (``tz.m.NAG``):
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+        tz.m.LR(0.1)
+    )
+    ```
 
     """
     def __init__(
         self,
         beta1: float = 0.9,
         beta2: float = 0.999,
-        averaging: bool = False,
-        block_size: int | None = 9,
+        averaging: bool = True,
+        block_size: int | None = None,
         update_freq: int = 1,
         eps: float = 1e-8,
+        hessian_power: float = 1,
         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
         fd_h: float = 1e-3,
         n_samples = 1,
         seed: int | None = None,
         inner: Chainable | None = None
     ):
-        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+        defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
         super().__init__(defaults)
 
         if inner is not None:
@@ -170,14 +173,10 @@ class AdaHessian(Module):
         n_samples = settings['n_samples']
 
         seed = settings['seed']
-        generator = None
-        if seed is not None:
-            if 'generator' not in self.global_state:
-                self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-            generator = self.global_state['generator']
+        generator = self.get_generator(params[0].device, seed)
 
-        beta1, beta2, eps, averaging, block_size = self.get_settings(params,
-            'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+        beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
+            'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)
 
         exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
 
@@ -196,6 +195,7 @@ class AdaHessian(Module):
 
             Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
                                   h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+            Hvp = tuple(Hvp)
 
             if D is None: D = Hvp
             else: torch._foreach_add_(D, Hvp)
@@ -218,6 +218,7 @@ class AdaHessian(Module):
             beta2=beta2,
             update_freq=update_freq,
             eps=eps,
+            hessian_power=hessian_power,
             step=step,
         )
         return var
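
For reference, the `hessian_power` parameter introduced above changes the preconditioner denominator from the bias-corrected root `sqrt(D_ema / (1 - beta2))` to `sqrt(D_ema) ** hessian_power`, with the bias correction moved into the numerator. A minimal single-tensor sketch of the resulting update, assuming `debiased_step_size` is the usual Adam correction factor `sqrt(1 - beta2**t) / (1 - beta1**t)` (the helper lives in `torchzero/modules/functional.py` and its exact definition is not shown in this diff):

```python
import math
import torch

def adahessian_update_sketch(exp_avg, D_exp_avg_sq, beta1, beta2, eps, hessian_power, step):
    # numerator: first-momentum EMA scaled by the assumed Adam-style debiasing factor
    num = exp_avg * (math.sqrt(1 - beta2**step) / (1 - beta1**step))
    # denominator: root of the Hessian-diagonal EMA raised to hessian_power, as in the hunk above
    denom = D_exp_avg_sq.sqrt().pow(hessian_power).add(eps)
    return num / denom
```

With `hessian_power=1` (the new default) the denominator is the plain root of the Hessian-diagonal EMA, so the change relative to the old code is mainly where the debiasing is applied.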
@@ -10,9 +10,6 @@ from ..functional import (
     ema_,
     sqrt_ema_sq_,
 )
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
 
 
 def adam_(
@@ -9,37 +9,38 @@ def adan_(
     m_: TensorList, # exponential moving average
     v_: TensorList, # exponential moving average of gradient differences
     n_: TensorList, # kinda like squared momentum
-    n_prev_: TensorList | None,
     beta1: float | NumberList,
     beta2: float | NumberList,
     beta3: float | NumberList,
     eps: float | NumberList,
-    use_n_prev: bool,
+    step: int,
 ):
-    """Returns new tensors."""
-    m_.lerp_(g, 1-beta1)
+    """Returns new tensors"""
+    m_.lerp_(g, 1 - beta1)
 
-    y = g - g_prev_
-    v_.lerp_(y, 1-beta2)
+    if step == 1:
+        term = g
+    else:
+        diff = g - g_prev_
+        v_.lerp_(diff, 1 - beta2)
+        term = g + beta2 * diff
 
-    y.mul_(1-beta2).add_(g)
-    n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+    n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))
 
-    if use_n_prev:
-        assert n_prev_ is not None
-        ns = n_prev_.clone()
-        n_prev_.copy_(n_)
-        n_ = ns
+    m = m_ / (1.0 - beta1**step)
+    v = v_ / (1.0 - beta2**step)
+    n = n_ / (1.0 - beta3**step)
 
-    eta = n_.sqrt().add_(eps).reciprocal_()
-    term = m_ + (1-beta2)*v_
-    update = eta.mul_(term)
+    denom = n.sqrt_().add_(eps)
+    num = m + beta2 * v
 
+    update = num.div_(denom)
     g_prev_.copy_(g)
 
     return update
 
 
+
 class Adan(Transform):
     """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
 
@@ -51,6 +52,13 @@ class Adan(Transform):
         use_n_prev (bool, optional):
             whether to use previous gradient differences momentum.
 
+    Example:
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.Adan(),
+        tz.m.LR(1e-3),
+    )
     Reference:
         Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
     """
@@ -60,9 +68,8 @@ class Adan(Transform):
         beta2: float = 0.92,
         beta3: float = 0.99,
         eps: float = 1e-8,
-        use_n_prev: bool = False,
     ):
-        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,use_n_prev=use_n_prev)
+        defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps)
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
@@ -71,40 +78,19 @@ class Adan(Transform):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
-        s = settings[0]
-        use_n_prev = s['use_n_prev']
-
         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
 
-
-        if use_n_prev:
-            n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
-        else:
-            n_prev = None
-
-        if step == 1:
-            # initial values, also runs on restarts
-            m.copy_(tensors)
-            n.set_(tensors ** 2)
-            v.zero_()
-            g_prev.copy_(tensors)
-            if n_prev is not None: n_prev.set_(tensors ** 2)
-
-        if step == 2:
-            v.set_(tensors - g_prev)
-
         update = adan_(
             g=tensors,
             g_prev_=g_prev,
             m_=m,
             v_=v,
             n_=n,
-            n_prev_=n_prev,
             beta1=beta1,
             beta2=beta2,
             beta3=beta3,
             eps=eps,
-            use_n_prev=use_n_prev,
+            step=step,
         )
 
         return update
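
Spelled out for a single tensor, the rewritten `adan_` above is the bias-corrected Adan step: an EMA of gradients, an EMA of gradient differences, a squared EMA of the combined term, and a division by its root. A sketch using out-of-place tensor ops (the package operates in place on `TensorList`s; this is not its API):

```python
import torch

def adan_step_sketch(g, g_prev, m, v, n, beta1, beta2, beta3, eps, step):
    m = beta1 * m + (1 - beta1) * g                # EMA of gradients
    if step == 1:
        term = g                                   # no previous gradient yet
    else:
        diff = g - g_prev
        v = beta2 * v + (1 - beta2) * diff         # EMA of gradient differences
        term = g + beta2 * diff
    n = beta3 * n + (1 - beta3) * term * term      # EMA of the squared combined term

    m_hat = m / (1 - beta1**step)                  # bias corrections, as in the hunk above
    v_hat = v / (1 - beta2**step)
    n_hat = n / (1 - beta3**step)

    update = (m_hat + beta2 * v_hat) / (n_hat.sqrt() + eps)
    return update, m, v, n                         # caller stores g as the new g_prev

g = torch.randn(4)
zeros = torch.zeros(4)
update, m, v, n = adan_step_sketch(g, zeros, zeros.clone(), zeros.clone(), zeros.clone(),
                                   beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, step=1)
```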
@@ -4,7 +4,7 @@ from ...utils import TensorList, unpack_dicts, unpack_states
 
 
 def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
-    if f - f_star <= torch.finfo(p[0].dtype).eps: return g
+    if f - f_star <= torch.finfo(p[0].dtype).tiny * 2: return g
 
     g_g = g.dot(g)
     g_gp = g.dot(g_prev)
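
The guard in `adaptive_heavy_ball` now compares against twice the smallest positive normal number instead of the machine epsilon, so the early return only triggers when the objective is essentially exactly at `f_star`. For float32 the two thresholds differ by about 31 orders of magnitude, which can be checked directly:

```python
import torch

print(torch.finfo(torch.float32).eps)   # ~1.19e-07  (old threshold)
print(torch.finfo(torch.float32).tiny)  # ~1.18e-38  (new threshold is twice this value)
```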
@@ -21,14 +21,12 @@ class AdaptiveHeavyBall(Transform):
 
     This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
 
-    .. note::
+    note:
         The step size is determined by the algorithm, so learning rate modules shouldn't be used.
 
     Args:
         f_star (int, optional):
             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
-        tol (float, optional):
-            tolerance on objective value change.
     """
     def __init__(self, f_star: float = 0):
         defaults = dict(f_star=f_star)
@@ -38,8 +36,7 @@ class AdaptiveHeavyBall(Transform):
     def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert loss is not None
         tensors = TensorList(tensors)
-        setting = settings[0]
-        f_star = setting['f_star']
+        f_star = self.defaults['f_star']
 
         f_prev = self.global_state.get('f_prev', None)
         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
@@ -0,0 +1,54 @@
+import math
+
+import torch
+
+from ...core import Transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+# i've verified, it is identical to official
+# https://github.com/txping/AEGD/blob/master/aegd.py
+def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c:float|NumberList=1, eta:float|NumberList=0.1) -> TensorList:
+    v = g / (2 * (f + c)**0.5)
+    r_ /= 1 + (v ** 2).mul_(2*eta) # update energy
+    return 2*eta * r_*v # pyright:ignore[reportReturnType]
+
+class AEGD(Transform):
+    """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.
+
+    Note:
+        AEGD has a learning rate hyperparameter that can't really be removed from the update rule.
+        To avoid compounding learning rate mofications, remove the ``tz.m.LR`` module if you had it.
+
+    Args:
+        eta (float, optional): step size. Defaults to 0.1.
+        c (float, optional): c. Defaults to 1.
+        beta3 (float, optional): thrid (squared) momentum. Defaults to 0.1.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        use_n_prev (bool, optional):
+            whether to use previous gradient differences momentum.
+    """
+    def __init__(
+        self,
+        lr: float = 0.1,
+        c: float = 1,
+    ):
+        defaults=dict(c=c,lr=lr)
+        super().__init__(defaults, uses_loss=True)
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        assert loss is not None
+        tensors = TensorList(tensors)
+
+        c,lr=unpack_dicts(settings, 'c','lr', cls=NumberList)
+        r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss+c[0])**0.5), cls=TensorList)
+
+        update = aegd_(
+            f=loss,
+            g=tensors,
+            r_=r,
+            c=c,
+            eta=lr,
+        )
+
+        return update
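
The new `aegd.py` keeps an energy variable `r`, initialized to `sqrt(loss + c)`, shrinks it by `1 / (1 + 2*eta*v**2)` each step, and returns the update `2*eta*r*v` with `v = g / (2*sqrt(loss + c))`. Since `AEGD` is a `Transform`, it should compose with `tz.Modular` like the other modules in this diff; a hedged usage sketch (whether it is exported as `tz.m.AEGD`, and the closure protocol for loss-using modules, are assumptions not shown in this diff):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# AEGD carries its own step size (`lr`), so no separate tz.m.LR module,
# as the Note in its docstring advises.
opt = tz.Modular(model.parameters(), tz.m.AEGD(lr=0.1))

# AEGD is constructed with uses_loss=True, so each step needs a closure
# that returns the loss (see the package documentation for the closure signature).
```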
@@ -61,7 +61,7 @@ class ESGD(Module):
             more accurate HVP approximation. This requires two extra
             gradient evaluations.
             Defaults to "autograd".
-        h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+        fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
         n_samples (int, optional):
             number of hessian-vector products with random vectors to evaluate each time when updating
             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -5,8 +5,12 @@ import warnings
 import torch
 from ...core import Chainable, TensorwiseTransform
 
-def lm_adagrad_update(history: deque[torch.Tensor], damping, rdamping):
-    M = torch.stack(tuple(history), dim=1)# / len(history)
+def lm_adagrad_update(history: deque[torch.Tensor] | torch.Tensor, damping, rdamping):
+    if isinstance(history, torch.Tensor):
+        M = history
+    else:
+        M = torch.stack(tuple(history), dim=1)# / len(history)
+
     MTM = M.T @ M
     if damping != 0:
         MTM.add_(torch.eye(MTM.size(0), device=MTM.device, dtype=MTM.dtype).mul_(damping))
@@ -58,47 +62,45 @@ class LMAdagrad(TensorwiseTransform):
             order=2 means gradient differences are used in place of gradients. Higher order uses higher order differences. Defaults to 1.
         true_damping (bool, optional):
             If True, damping is added to squared singular values to mimic Adagrad. Defaults to True.
-        eigh (bool, optional): uses a more efficient way to calculate U and S. Defaults to True.
         U_beta (float | None, optional): momentum for U (too unstable, don't use). Defaults to None.
-        S_beta (float | None, optional): momentum for S (too unstable, don't use). Defaults to None.
+        L_beta (float | None, optional): momentum for L (too unstable, don't use). Defaults to None.
         interval (int, optional): Interval between gradients that are added to history (2 means every second gradient is used). Defaults to 1.
         concat_params (bool, optional): if True, treats all parameters as a single vector, meaning it will also whiten inter-parameters. Defaults to True.
         inner (Chainable | None, optional): preconditioner will be applied to output of this module. Defaults to None.
 
-    Examples:
-        Limited-memory Adagrad
-
-        .. code-block:: python
-
-            optimizer = tz.Modular(
-                model.parameters(),
-                tz.m.LMAdagrad(),
-                tz.m.LR(0.1)
-            )
-
-        Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
-
-        .. code-block:: python
-
-            optimizer = tz.Modular(
-                model.parameters(),
-                tz.m.LMAdagrad(inner=tz.m.EMA()),
-                tz.m.Debias(0.9, 0.999),
-                tz.m.LR(0.01)
-            )
-
-        Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
-
-        .. code-block:: python
-
-            optimizer = tz.Modular(
-                model.parameters(),
-                tz.m.LMAdagrad(inner=tz.m.EMA()),
-                tz.m.Debias(0.9, 0.999),
-                tz.m.ClipNormByEMA(max_ema_growth=1.2),
-                tz.m.LR(0.01)
-            )
-
+    ## Examples:
+
+    Limited-memory Adagrad
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(),
+        tz.m.LR(0.1)
+    )
+    ```
+    Adam with L-Adagrad preconditioner (for debiasing second beta is 0.999 arbitrarily)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.LR(0.01)
+    )
+    ```
+
+    Stable Adam with L-Adagrad preconditioner (this is what I would recommend)
+
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.LMAdagrad(inner=tz.m.EMA()),
+        tz.m.Debias(0.9, 0.999),
+        tz.m.ClipNormByEMA(max_ema_growth=1.2),
+        tz.m.LR(0.01)
+    )
+    ```
     Reference:
         Agarwal N. et al. Efficient full-matrix adaptive regularization //International Conference on Machine Learning. – PMLR, 2019. – С. 102-110.
     """
@@ -143,6 +145,7 @@ class LMAdagrad(TensorwiseTransform):
         # scaled by parameter differences
         cur_p = param.clone()
         cur_g = tensor.clone()
+        eps = torch.finfo(cur_p.dtype).tiny * 2
         for i in range(1, order):
             if f'prev_g_{i}' not in state:
                 state[f'prev_p_{i}'] = cur_p
@@ -157,7 +160,7 @@ class LMAdagrad(TensorwiseTransform):
             cur_g = y
 
             if i == order - 1:
-                cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=1e-8) # pylint:disable=not-callable
+                cur_g = cur_g / torch.linalg.norm(cur_p).clip(min=eps) # pylint:disable=not-callable
         history.append(cur_g.view(-1))
 
         step = state.get('step', 0)
@@ -1,18 +1,7 @@
-from operator import itemgetter
-from functools import partial
-
 import torch
 
-from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...core import Transform
 from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
-from ..functional import (
-    debias, debiased_step_size,
-    ema_,
-    sqrt_ema_sq_,
-)
-from ..step_size.lr import lazy_lr
-from ..momentum.experimental import sqrt_nag_ema_sq_
-from ..momentum.momentum import nag_
 
 
 def mars_correction_(
@@ -35,36 +24,35 @@ class MARSCorrection(Transform):
     """MARS variance reduction correction.
 
     Place any other momentum-based optimizer after this,
-    make sure :code:`beta` parameter matches with momentum in the optimizer.
+    make sure ``beta`` parameter matches with momentum in the optimizer.
 
     Args:
         beta (float, optional): use the same beta as you use in the momentum module. Defaults to 0.9.
         scaling (float, optional): controls the scale of gradient correction in variance reduction. Defaults to 0.025.
         max_norm (float, optional): clips norm of corrected gradients, None to disable. Defaults to 1.
 
-    Examples:
-        Mars-AdamW
-
-        .. code-block:: python
-
-            optimizer = tz.Modular(
-                model.parameters(),
-                tz.m.MARSCorrection(beta=0.95),
-                tz.m.Adam(beta1=0.95, beta2=0.99),
-                tz.m.WeightDecay(1e-3),
-                tz.m.LR(0.1)
-            )
-
-        Mars-Lion
-
-        .. code-block:: python
-
-            optimizer = tz.Modular(
-                model.parameters(),
-                tz.m.MARSCorrection(beta=0.9),
-                tz.m.Lion(beta1=0.9),
-                tz.m.LR(0.1)
-            )
+    ## Examples:
+
+    Mars-AdamW
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.MARSCorrection(beta=0.95),
+        tz.m.Adam(beta1=0.95, beta2=0.99),
+        tz.m.WeightDecay(1e-3),
+        tz.m.LR(0.1)
+    )
+    ```
+
+    Mars-Lion
+    ```python
+    optimizer = tz.Modular(
+        model.parameters(),
+        tz.m.MARSCorrection(beta=0.9),
+        tz.m.Lion(beta1=0.9),
+        tz.m.LR(0.1)
+    )
+    ```
 
     """
     def __init__(