torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/adahessian.py (new file)
@@ -0,0 +1,223 @@
+ import math
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
+
+
+ def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
+     """averages x over first dimension in blocks"""
+     if enable and x.ndim >= 2:
+         if math.prod(x.shape[1:]) <= 1: return x
+         size = x.size(0)
+         if block_size is None: return x.mean(0, keepdim=True)
+
+         n_blocks = size // block_size
+         if n_blocks <= 1: return x.mean(0, keepdim = True)
+
+         n_remaining = size - n_blocks * block_size
+         remaining = None
+         if n_remaining > 0:
+             remaining = x[-n_remaining:].mean(0, keepdim=True).repeat_interleave(n_remaining, 0)
+             x = x[:-n_remaining]
+
+         x = x.view(block_size, n_blocks, *x.shape[1:])
+         x_mean = x.mean(0).repeat_interleave(block_size, 0)
+
+         if remaining is None: return x_mean
+         return torch.cat([x_mean, remaining], 0)
+
+     return x
+
+ def _rademacher_like(tensor, p = 0.5, generator = None):
+     """p is probability of a 1, other values will be -1."""
+     return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
+
+ def adahessian(
+     tensors: TensorList,
+     D: TensorList | None,
+     exp_avg_: TensorList,
+     D_exp_avg_sq_: TensorList,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     update_freq: int,
+     eps: float | NumberList,
+     step: int,
+ ):
+     # momentum
+     exp_avg_.lerp_(tensors, 1-beta1)
+     num = exp_avg_ / (1-beta1)
+
+     # update preconditioner
+     if step % update_freq == 0:
+         assert D is not None
+         D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
+
+     else:
+         assert D is None
+
+     denom = (D_exp_avg_sq_ / (1-beta2)).sqrt_().add_(eps)
+
+     return num.div_(denom)
+
+
+ class AdaHessian(Module):
+     """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
+
+     This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
+
+     .. note::
+         In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply AdaHessian preconditioning to another module's output.
+
+     .. note::
+         If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+     .. note::
+         This module requires a closure passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating HVPs.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     Args:
+         beta1 (float, optional): first momentum. Defaults to 0.9.
+         beta2 (float, optional): second momentum for squared hessian diagonal estimates. Defaults to 0.999.
+         averaging (bool, optional):
+             whether to enable block diagonal averaging over 1st dimension on parameters that have 2+ dimensions.
+             This can be set per-parameter in param groups.
+         block_size (int, optional):
+             size of block in the block-diagonal averaging.
+         update_freq (int, optional):
+             frequency of updating hessian diagonal estimate via a hessian-vector product.
+             This value can be increased to reduce computational cost. Defaults to 1.
+         eps (float, optional):
+             division stability epsilon. Defaults to 1e-8.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: Use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: Use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+             Defaults to "autograd".
+         h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+         n_samples (int, optional):
+             number of hessian-vector products with random vectors to evaluate each time when updating
+             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+         seed (int | None, optional): seed for random vectors. Defaults to None.
+         inner (Chainable | None, optional):
+             Inner module. If this is specified, operations are performed in the following order.
+             1. compute hessian diagonal estimate.
+             2. pass inputs to :code:`inner`.
+             3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
+
+     Examples:
+         Using AdaHessian:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.AdaHessian(),
+                 tz.m.LR(0.1)
+             )
+
+         AdaHessian preconditioner can be applied to any other module by passing it to the :code:`inner` argument.
+         Turn off AdaHessian's first momentum to get just the preconditioning. Here is an example of applying
+         AdaHessian preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.AdaHessian(beta1=0, inner=tz.m.NAG(0.9)),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         beta1: float = 0.9,
+         beta2: float = 0.999,
+         averaging: bool = False,
+         block_size: int | None = 9,
+         update_freq: int = 1,
+         eps: float = 1e-8,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         beta1, beta2, eps, averaging, block_size = self.get_settings(params,
+             'beta1', 'beta2', 'eps', 'averaging', 'block_size', cls=NumberList)
+
+         exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = var.closure
+         assert closure is not None
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad=None
+             for i in range(n_samples):
+                 u = [_rademacher_like(p, generator=generator) for p in params]
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=fd_h, normalize=True, retain_grad=i < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             assert D is not None
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+             D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+         var.update = adahessian(
+             tensors=TensorList(update),
+             D=TensorList(D) if D is not None else None,
+             exp_avg_=exp_avg,
+             D_exp_avg_sq_=D_exp_avg_sq,
+             beta1=beta1,
+             beta2=beta2,
+             update_freq=update_freq,
+             eps=eps,
+             step=step,
+         )
+         return var
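The adahessian() function above is an Adam-style update whose second moment is built from Hutchinson estimates of the Hessian diagonal (D ≈ z ⊙ Hz with a Rademacher probe z, obtained through self.Hvp). A rough single-tensor sketch in plain PyTorch, outside the torchzero module system, is shown below; the helper names and the standard beta**step bias correction are illustrative assumptions, not the library's code.

# Rough single-tensor sketch of the AdaHessian rule above, in plain PyTorch.
# Helper names are illustrative; bias correction here is standard Adam-style.
import torch

def rademacher_like(t: torch.Tensor) -> torch.Tensor:
    # random +/-1 probe vector
    return torch.randint(0, 2, t.shape, device=t.device).to(t.dtype) * 2 - 1

def hutchinson_diag(loss_fn, param: torch.Tensor) -> torch.Tensor:
    # D ~= z * (H z): one Hessian-vector product against a Rademacher probe
    z = rademacher_like(param)
    loss = loss_fn(param)
    (g,) = torch.autograd.grad(loss, param, create_graph=True)
    (hz,) = torch.autograd.grad(g, param, grad_outputs=z)
    return z * hz

def adahessian_step(param, grad, D, exp_avg, D_exp_avg_sq,
                    lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Adam, but the second moment tracks D**2 instead of grad**2
    exp_avg.lerp_(grad, 1 - beta1)
    D_exp_avg_sq.mul_(beta2).addcmul_(D, D, value=1 - beta2)
    num = exp_avg / (1 - beta1 ** step)
    denom = (D_exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
    with torch.no_grad():
        param.sub_(lr * num / denom)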
torchzero/modules/optimizers/adam.py
@@ -3,14 +3,14 @@ from functools import partial
 
  import torch
 
- from ...core import Module, Target, Transform
- from ...utils import NumberList, TensorList
+ from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
  from ..functional import (
      debias, debiased_step_size,
      ema_,
      sqrt_ema_sq_,
  )
- from ..lr.lr import lazy_lr
+ from ..step_size.lr import lazy_lr
  from ..momentum.experimental import sqrt_nag_ema_sq_
  from ..momentum.momentum import nag_
 
@@ -27,26 +27,28 @@ def adam_(
      pow: float = 2,
      debiased: bool = True,
      max_exp_avg_sq_: TensorList | None = None,
-     params_: TensorList | None = None,
- ):
-     """Returns new tensors or updates params in-place."""
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
 
+     # inner args
+     inner: Module | None = None,
+     params: list[torch.Tensor] | None = None,
+     grads: list[torch.Tensor] | None = None,
+ ):
+     """Returns new tensors."""
      sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
                                     debiased=False,step=step,pow=pow)
 
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+     if inner is not None:
+         assert params is not None
+         tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
 
-     # params is None, return update
-     if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+     return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
 
-     # update params in-place
-     params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-     return None
+ class Adam(Transform):
+     """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
 
- class Adam(Module):
-     """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
-     pytorch in that debiasing is applied after adding epsilon.
+     This implementation is identical to :code:`torch.optim.Adam`.
 
      Args:
          beta1 (float, optional): momentum. Defaults to 0.9.
@@ -66,36 +68,29 @@ class Adam(Module):
          alpha: float = 1.,
          pow: float = 2,
          debiased: bool = True,
+         inner: Chainable | None = None
      ):
          defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults)
-         self.getter = itemgetter('amsgrad','pow','debiased')
+         super().__init__(defaults, uses_grad=False)
+
+         if inner is not None: self.set_child('inner', inner)
 
      @torch.no_grad
-     def step(self, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
-         amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+         amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
 
          if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
-         # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-         if vars.is_last:
-             if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-             passed_params = TensorList(vars.params)
-             vars.stop = True
-             vars.skip_update = True
-
-         else:
-             passed_params = None
 
-         vars.update = adam_(
-             tensors=TensorList(vars.get_update()),
+         return adam_(
+             tensors=TensorList(tensors),
              exp_avg_=exp_avg,
              exp_avg_sq_=exp_avg_sq,
              alpha=alpha,
@@ -106,7 +101,10 @@ class Adam(Module):
              pow=pow,
              debiased=debiased,
              max_exp_avg_sq_=max_exp_avg_sq,
-             params_=passed_params,
-         )
 
-         return vars
+             # inner args
+             inner=self.children.get("inner", None),
+             params=params,
+             grads=grads,
+
+         )
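With this rework, Adam is a Transform and accepts an inner child module: per the adam_ changes above, the squared-gradient EMA is built from the incoming tensors, inner is then applied, and the first-moment EMA and debiased step size are applied to its output. A hedged usage sketch, assuming the tz.Modular / tz.m entry points used in the other docstrings in this diff:

# Hedged example (not taken from the package docs): chaining the new `inner` argument.
import torchzero as tz

opt = tz.Modular(
    model.parameters(),              # assumes an existing nn.Module called `model`
    tz.m.Adam(inner=tz.m.NAG(0.9)),  # second moment from raw grads, first moment from NAG output
    tz.m.LR(1e-3),
)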
torchzero/modules/optimizers/adan.py (new file)
@@ -0,0 +1,110 @@
+ import torch
+
+ from ...core import Transform
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+
+ def adan_(
+     g: TensorList,
+     g_prev_: TensorList,
+     m_: TensorList, # exponential moving average
+     v_: TensorList, # exponential moving average of gradient differences
+     n_: TensorList, # kinda like squared momentum
+     n_prev_: TensorList | None,
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     beta3: float | NumberList,
+     eps: float | NumberList,
+     use_n_prev: bool,
+ ):
+     """Returns new tensors."""
+     m_.lerp_(g, 1-beta1)
+
+     y = g - g_prev_
+     v_.lerp_(y, 1-beta2)
+
+     y.mul_(1-beta2).add_(g)
+     n_.mul_(beta3).addcmul_(y, y, 1-beta3)
+
+     if use_n_prev:
+         assert n_prev_ is not None
+         ns = n_prev_.clone()
+         n_prev_.copy_(n_)
+         n_ = ns
+
+     eta = n_.sqrt().add_(eps).reciprocal_()
+     term = m_ + (1-beta2)*v_
+     update = eta.mul_(term)
+
+     g_prev_.copy_(g)
+
+     return update
+
+
+ class Adan(Transform):
+     """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
+
+     Args:
+         beta1 (float, optional): momentum. Defaults to 0.98.
+         beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
+         beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
+         eps (float, optional): epsilon. Defaults to 1e-8.
+         use_n_prev (bool, optional):
+             whether to use previous gradient differences momentum.
+
+     Reference:
+         Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+     """
+     def __init__(
+         self,
+         beta1: float = 0.98,
+         beta2: float = 0.92,
+         beta3: float = 0.99,
+         eps: float = 1e-8,
+         use_n_prev: bool = False,
+     ):
+         defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps,use_n_prev=use_n_prev)
+         super().__init__(defaults, uses_grad=False)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+         beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
+         s = settings[0]
+         use_n_prev = s['use_n_prev']
+
+         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
+
+
+         if use_n_prev:
+             n_prev = unpack_states(states, tensors, 'n_prev', cls=TensorList)
+         else:
+             n_prev = None
+
+         if step == 1:
+             # initial values, also runs on restarts
+             m.copy_(tensors)
+             n.set_(tensors ** 2)
+             v.zero_()
+             g_prev.copy_(tensors)
+             if n_prev is not None: n_prev.set_(tensors ** 2)
+
+         if step == 2:
+             v.set_(tensors - g_prev)
+
+         update = adan_(
+             g=tensors,
+             g_prev_=g_prev,
+             m_=m,
+             v_=v,
+             n_=n,
+             n_prev_=n_prev,
+             beta1=beta1,
+             beta2=beta2,
+             beta3=beta3,
+             eps=eps,
+             use_n_prev=use_n_prev,
+         )
+
+         return update
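Written out for a single tensor, adan_ above maintains an EMA of gradients, an EMA of gradient differences, and an EMA of squared Nesterov-corrected gradients that forms the denominator. A simplified plain-PyTorch restatement, dropping the use_n_prev branch and the TensorList machinery:

# Simplified per-tensor restatement of adan_ (plain PyTorch, no TensorList,
# use_n_prev branch omitted); defaults mirror the Adan class above.
import torch

def adan_step(g, g_prev, m, v, n, beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8):
    m.lerp_(g, 1 - beta1)                             # EMA of gradients
    y = g - g_prev                                    # gradient difference
    v.lerp_(y, 1 - beta2)                             # EMA of gradient differences
    gy = g + (1 - beta2) * y                          # Nesterov-style corrected gradient
    n.mul_(beta3).addcmul_(gy, gy, value=1 - beta3)   # EMA of its square
    update = (m + (1 - beta2) * v) / (n.sqrt() + eps)
    g_prev.copy_(g)
    return update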
torchzero/modules/optimizers/adaptive_heavyball.py (new file)
@@ -0,0 +1,57 @@
+ import torch
+ from ...core import Transform
+ from ...utils import TensorList, unpack_dicts, unpack_states
+
+
+ def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p: TensorList, p_prev: TensorList):
+     if f - f_star <= torch.finfo(p[0].dtype).eps: return g
+
+     g_g = g.dot(g)
+     g_gp = g.dot(g_prev)
+     num = -(f - f_star) * g.dot(g_prev)
+     denom = (f_prev - f_star) * g_g + (f - f_star) * g_gp
+     m = num/denom
+
+     h = 2*(f - f_star) / g_g
+     return (1 + m) * h * g - m*(p-p_prev)
+
+
+ class AdaptiveHeavyBall(Transform):
+     """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
+
+     This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
+
+     .. note::
+         The step size is determined by the algorithm, so learning rate modules shouldn't be used.
+
+     Args:
+         f_star (int, optional):
+             (estimated) minimal possible value of the objective function (lowest possible loss). Defaults to 0.
+         tol (float, optional):
+             tolerance on objective value change.
+     """
+     def __init__(self, f_star: float = 0):
+         defaults = dict(f_star=f_star)
+         super().__init__(defaults, uses_grad=False, uses_loss=True)
+
+     @torch.no_grad
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         assert loss is not None
+         tensors = TensorList(tensors)
+         setting = settings[0]
+         f_star = setting['f_star']
+
+         f_prev = self.global_state.get('f_prev', None)
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
+
+         if f_prev is None:
+             self.global_state['f_prev'] = loss
+             h = 2*(loss - f_star) / tensors.dot(tensors)
+             return h * tensors
+
+         update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+
+         self.global_state['f_prev'] = loss
+         p_prev.copy_(params)
+         g_prev.copy_(tensors)
+         return update
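For reference, adaptive_heavy_ball above computes a Polyak-style step size h_k = 2(f_k - f*)/||g_k||^2 together with an adaptive momentum coefficient m_k, and returns (1 + m_k) h_k g_k - m_k (p_k - p_{k-1}). A flat-vector restatement in plain PyTorch:

# Flat-vector restatement of adaptive_heavy_ball (plain PyTorch tensors,
# scalars f, f_prev, f_star); the returned update is meant to be subtracted.
import torch

def ahb_update(f, f_prev, f_star, g, g_prev, p, p_prev):
    g_g = g.dot(g)             # ||g_k||^2
    g_gp = g.dot(g_prev)       # g_k . g_{k-1}
    m = -(f - f_star) * g_gp / ((f_prev - f_star) * g_g + (f - f_star) * g_gp)
    h = 2 * (f - f_star) / g_g                  # Polyak-like step size
    return (1 + m) * h * g - m * (p - p_prev)   # heavy-ball step with adaptive momentum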
torchzero/modules/optimizers/esgd.py (new file)
@@ -0,0 +1,171 @@
+ import math
+ from collections.abc import Callable
+ from typing import Literal
+
+ import torch
+
+ from ...core import Chainable, Module, Target, Transform, apply_transform
+ from ...utils import NumberList, TensorList, as_tensorlist
+
+
+ def esgd_(
+     tensors_: TensorList,
+     D: TensorList | None,
+     D_sq_acc_: TensorList,
+     damping: float | NumberList,
+     update_freq: int,
+     step: int,
+     i: int,
+ ):
+     # update preconditioner
+     if step % update_freq == 0:
+         assert D is not None
+         D_sq_acc_.addcmul_(D, D)
+         i += 1
+     else:
+         assert D is None
+
+     denom = (D_sq_acc_ / max(i, 1)).sqrt_().add_(damping)
+     return tensors_.div_(denom), i
+
+
+ class ESGD(Module):
+     """Equilibrated Gradient Descent (https://arxiv.org/abs/1502.04390)
+
+     This is similar to Adagrad, but the accumulates squared randomized hessian diagonal estimates instead of squared gradients.
+
+     .. note::
+         In most cases Adagrad should be the first module in the chain because it relies on autograd. Use the :code:`inner` argument if you wish to apply Adagrad preconditioning to another module's output.
+
+     .. note::
+         If you are using gradient estimators or reformulations, set :code:`hvp_method` to "forward" or "central".
+
+     .. note::
+         This module requires a closure passed to the optimizer step,
+         as it needs to re-evaluate the loss and gradients for calculating HVPs.
+         The closure must accept a ``backward`` argument (refer to documentation).
+
+     Args:
+         damping (float, optional): added to denominator for stability. Defaults to 1e-4.
+         update_freq (int, optional):
+             frequency of updating hessian diagonal estimate via a hessian-vector product.
+             This value can be increased to reduce computational cost. Defaults to 20.
+         hvp_method (str, optional):
+             Determines how Hessian-vector products are evaluated.
+
+             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+               This requires creating a graph for the gradient.
+             - ``"forward"``: Use a forward finite difference formula to
+               approximate the HVP. This requires one extra gradient evaluation.
+             - ``"central"``: Use a central finite difference formula for a
+               more accurate HVP approximation. This requires two extra
+               gradient evaluations.
+             Defaults to "autograd".
+         h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
+         n_samples (int, optional):
+             number of hessian-vector products with random vectors to evaluate each time when updating
+             the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
+         seed (int | None, optional): seed for random vectors. Defaults to None.
+         inner (Chainable | None, optional):
+             Inner module. If this is specified, operations are performed in the following order.
+             1. compute hessian diagonal estimate.
+             2. pass inputs to :code:`inner`.
+             3. momentum and preconditioning are applied to the ouputs of :code:`inner`.
+
+     Examples:
+         Using ESGD:
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(),
+                 tz.m.LR(0.1)
+             )
+
+         ESGD preconditioner can be applied to any other module by passing it to the :code:`inner` argument. Here is an example of applying
+         ESGD preconditioning to nesterov momentum (:code:`tz.m.NAG`):
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.ESGD(beta1=0, inner=tz.m.NAG(0.9)),
+                 tz.m.LR(0.1)
+             )
+
+     """
+     def __init__(
+         self,
+         damping: float = 1e-4,
+         update_freq: int = 20,
+         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
+         fd_h: float = 1e-3,
+         n_samples = 1,
+         seed: int | None = None,
+         inner: Chainable | None = None
+     ):
+         defaults = dict(damping=damping, update_freq=update_freq, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         super().__init__(defaults)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         settings = self.settings[params[0]]
+         hvp_method = settings['hvp_method']
+         fd_h = settings['fd_h']
+         update_freq = settings['update_freq']
+         n_samples = settings['n_samples']
+
+         seed = settings['seed']
+         generator = None
+         if seed is not None:
+             if 'generator' not in self.global_state:
+                 self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             generator = self.global_state['generator']
+
+         damping = self.get_settings(params, 'damping', cls=NumberList)
+         D_sq_acc = self.get_state(params, 'D_sq_acc', cls=TensorList)
+         i = self.global_state.get('i', 0)
+
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         closure = var.closure
+         assert closure is not None
+
+         D = None
+         if step % update_freq == 0:
+
+             rgrad=None
+             for j in range(n_samples):
+                 u = [torch.randn(p.size(), generator=generator, device=p.device, dtype=p.dtype) for p in params]
+
+                 Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
+                                       h=fd_h, normalize=True, retain_grad=j < n_samples-1)
+
+                 if D is None: D = Hvp
+                 else: torch._foreach_add_(D, Hvp)
+
+             assert D is not None
+             if n_samples > 1: torch._foreach_div_(D, n_samples)
+
+             D = TensorList(D)
+
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
+
+         var.update, self.global_state['i'] = esgd_(
+             tensors_=TensorList(update),
+             D=TensorList(D) if D is not None else None,
+             D_sq_acc_=D_sq_acc,
+             damping=damping,
+             update_freq=update_freq,
+             step=step,
+             i=i,
+         )
+         return var
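Unlike AdaHessian's exponential moving average, esgd_ above keeps an Adagrad-style running sum of squared Hessian-diagonal probes (Gaussian z, D ≈ z ⊙ Hz) and divides the update by the root of their mean. A per-tensor restatement of that accumulator:

# Per-tensor restatement of esgd_ (plain PyTorch): Adagrad-like accumulation of
# squared Hessian-diagonal estimates; `refresh` marks steps where a new HVP was taken.
import torch

def esgd_step(update, D, D_sq_acc, i, damping=1e-4, refresh=True):
    if refresh:
        D_sq_acc.addcmul_(D, D)   # accumulate D**2
        i += 1
    denom = (D_sq_acc / max(i, 1)).sqrt().add_(damping)
    return update / denom, i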