torchzero 0.3.15__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. tests/test_identical.py +2 -2
  2. tests/test_module_autograd.py +586 -0
  3. tests/test_objective.py +188 -0
  4. tests/test_opts.py +43 -33
  5. tests/test_tensorlist.py +0 -8
  6. tests/test_utils_optimizer.py +0 -1
  7. torchzero/__init__.py +1 -1
  8. torchzero/core/__init__.py +7 -4
  9. torchzero/core/chain.py +20 -23
  10. torchzero/core/functional.py +90 -24
  11. torchzero/core/modular.py +48 -52
  12. torchzero/core/module.py +130 -50
  13. torchzero/core/objective.py +948 -0
  14. torchzero/core/reformulation.py +55 -24
  15. torchzero/core/transform.py +261 -367
  16. torchzero/linalg/__init__.py +10 -0
  17. torchzero/linalg/eigh.py +34 -0
  18. torchzero/linalg/linalg_utils.py +14 -0
  19. torchzero/{utils/linalg → linalg}/linear_operator.py +99 -49
  20. torchzero/linalg/matrix_power.py +28 -0
  21. torchzero/linalg/orthogonalize.py +95 -0
  22. torchzero/{utils/linalg → linalg}/qr.py +4 -2
  23. torchzero/{utils/linalg → linalg}/solve.py +76 -88
  24. torchzero/linalg/svd.py +20 -0
  25. torchzero/linalg/torch_linalg.py +168 -0
  26. torchzero/modules/adaptive/__init__.py +1 -1
  27. torchzero/modules/adaptive/adagrad.py +163 -213
  28. torchzero/modules/adaptive/adahessian.py +74 -103
  29. torchzero/modules/adaptive/adam.py +53 -76
  30. torchzero/modules/adaptive/adan.py +49 -30
  31. torchzero/modules/adaptive/adaptive_heavyball.py +11 -6
  32. torchzero/modules/adaptive/aegd.py +12 -12
  33. torchzero/modules/adaptive/esgd.py +98 -119
  34. torchzero/modules/adaptive/lion.py +5 -10
  35. torchzero/modules/adaptive/lmadagrad.py +87 -32
  36. torchzero/modules/adaptive/mars.py +5 -5
  37. torchzero/modules/adaptive/matrix_momentum.py +47 -51
  38. torchzero/modules/adaptive/msam.py +70 -52
  39. torchzero/modules/adaptive/muon.py +59 -124
  40. torchzero/modules/adaptive/natural_gradient.py +33 -28
  41. torchzero/modules/adaptive/orthograd.py +11 -15
  42. torchzero/modules/adaptive/rmsprop.py +83 -75
  43. torchzero/modules/adaptive/rprop.py +48 -47
  44. torchzero/modules/adaptive/sam.py +55 -45
  45. torchzero/modules/adaptive/shampoo.py +123 -129
  46. torchzero/modules/adaptive/soap.py +207 -143
  47. torchzero/modules/adaptive/sophia_h.py +106 -130
  48. torchzero/modules/clipping/clipping.py +15 -18
  49. torchzero/modules/clipping/ema_clipping.py +31 -25
  50. torchzero/modules/clipping/growth_clipping.py +14 -17
  51. torchzero/modules/conjugate_gradient/cg.py +26 -37
  52. torchzero/modules/experimental/__init__.py +2 -6
  53. torchzero/modules/experimental/coordinate_momentum.py +36 -0
  54. torchzero/modules/experimental/curveball.py +25 -41
  55. torchzero/modules/experimental/gradmin.py +2 -2
  56. torchzero/modules/experimental/higher_order_newton.py +14 -40
  57. torchzero/modules/experimental/newton_solver.py +22 -53
  58. torchzero/modules/experimental/newtonnewton.py +15 -12
  59. torchzero/modules/experimental/reduce_outward_lr.py +7 -7
  60. torchzero/modules/experimental/scipy_newton_cg.py +21 -24
  61. torchzero/modules/experimental/spsa1.py +3 -3
  62. torchzero/modules/experimental/structural_projections.py +1 -4
  63. torchzero/modules/functional.py +1 -1
  64. torchzero/modules/grad_approximation/forward_gradient.py +7 -7
  65. torchzero/modules/grad_approximation/grad_approximator.py +23 -16
  66. torchzero/modules/grad_approximation/rfdm.py +20 -17
  67. torchzero/modules/least_squares/gn.py +90 -42
  68. torchzero/modules/line_search/backtracking.py +2 -2
  69. torchzero/modules/line_search/line_search.py +32 -32
  70. torchzero/modules/line_search/strong_wolfe.py +2 -2
  71. torchzero/modules/misc/debug.py +12 -12
  72. torchzero/modules/misc/escape.py +10 -10
  73. torchzero/modules/misc/gradient_accumulation.py +10 -78
  74. torchzero/modules/misc/homotopy.py +16 -8
  75. torchzero/modules/misc/misc.py +120 -122
  76. torchzero/modules/misc/multistep.py +50 -48
  77. torchzero/modules/misc/regularization.py +49 -44
  78. torchzero/modules/misc/split.py +30 -28
  79. torchzero/modules/misc/switch.py +37 -32
  80. torchzero/modules/momentum/averaging.py +14 -14
  81. torchzero/modules/momentum/cautious.py +34 -28
  82. torchzero/modules/momentum/momentum.py +11 -11
  83. torchzero/modules/ops/__init__.py +4 -4
  84. torchzero/modules/ops/accumulate.py +21 -21
  85. torchzero/modules/ops/binary.py +67 -66
  86. torchzero/modules/ops/higher_level.py +19 -19
  87. torchzero/modules/ops/multi.py +44 -41
  88. torchzero/modules/ops/reduce.py +26 -23
  89. torchzero/modules/ops/unary.py +53 -53
  90. torchzero/modules/ops/utility.py +47 -46
  91. torchzero/modules/projections/galore.py +1 -1
  92. torchzero/modules/projections/projection.py +43 -43
  93. torchzero/modules/quasi_newton/damping.py +1 -1
  94. torchzero/modules/quasi_newton/lbfgs.py +7 -7
  95. torchzero/modules/quasi_newton/lsr1.py +7 -7
  96. torchzero/modules/quasi_newton/quasi_newton.py +10 -10
  97. torchzero/modules/quasi_newton/sg2.py +19 -19
  98. torchzero/modules/restarts/restars.py +26 -24
  99. torchzero/modules/second_order/__init__.py +2 -2
  100. torchzero/modules/second_order/ifn.py +31 -62
  101. torchzero/modules/second_order/inm.py +49 -53
  102. torchzero/modules/second_order/multipoint.py +40 -80
  103. torchzero/modules/second_order/newton.py +57 -90
  104. torchzero/modules/second_order/newton_cg.py +102 -154
  105. torchzero/modules/second_order/nystrom.py +157 -177
  106. torchzero/modules/second_order/rsn.py +106 -96
  107. torchzero/modules/smoothing/laplacian.py +13 -12
  108. torchzero/modules/smoothing/sampling.py +11 -10
  109. torchzero/modules/step_size/adaptive.py +23 -23
  110. torchzero/modules/step_size/lr.py +15 -15
  111. torchzero/modules/termination/termination.py +32 -30
  112. torchzero/modules/trust_region/cubic_regularization.py +2 -2
  113. torchzero/modules/trust_region/levenberg_marquardt.py +25 -28
  114. torchzero/modules/trust_region/trust_cg.py +1 -1
  115. torchzero/modules/trust_region/trust_region.py +27 -22
  116. torchzero/modules/variance_reduction/svrg.py +21 -18
  117. torchzero/modules/weight_decay/__init__.py +2 -1
  118. torchzero/modules/weight_decay/reinit.py +83 -0
  119. torchzero/modules/weight_decay/weight_decay.py +12 -13
  120. torchzero/modules/wrappers/optim_wrapper.py +10 -10
  121. torchzero/modules/zeroth_order/cd.py +9 -6
  122. torchzero/optim/root.py +3 -3
  123. torchzero/optim/utility/split.py +2 -1
  124. torchzero/optim/wrappers/directsearch.py +27 -63
  125. torchzero/optim/wrappers/fcmaes.py +14 -35
  126. torchzero/optim/wrappers/mads.py +11 -31
  127. torchzero/optim/wrappers/moors.py +66 -0
  128. torchzero/optim/wrappers/nevergrad.py +4 -4
  129. torchzero/optim/wrappers/nlopt.py +31 -25
  130. torchzero/optim/wrappers/optuna.py +6 -13
  131. torchzero/optim/wrappers/pybobyqa.py +124 -0
  132. torchzero/optim/wrappers/scipy/__init__.py +7 -0
  133. torchzero/optim/wrappers/scipy/basin_hopping.py +117 -0
  134. torchzero/optim/wrappers/scipy/brute.py +48 -0
  135. torchzero/optim/wrappers/scipy/differential_evolution.py +80 -0
  136. torchzero/optim/wrappers/scipy/direct.py +69 -0
  137. torchzero/optim/wrappers/scipy/dual_annealing.py +115 -0
  138. torchzero/optim/wrappers/scipy/experimental.py +141 -0
  139. torchzero/optim/wrappers/scipy/minimize.py +151 -0
  140. torchzero/optim/wrappers/scipy/sgho.py +111 -0
  141. torchzero/optim/wrappers/wrapper.py +121 -0
  142. torchzero/utils/__init__.py +7 -25
  143. torchzero/utils/compile.py +2 -2
  144. torchzero/utils/derivatives.py +93 -69
  145. torchzero/utils/optimizer.py +4 -77
  146. torchzero/utils/python_tools.py +31 -0
  147. torchzero/utils/tensorlist.py +11 -5
  148. torchzero/utils/thoad_tools.py +68 -0
  149. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/METADATA +1 -1
  150. torchzero-0.4.0.dist-info/RECORD +191 -0
  151. tests/test_vars.py +0 -185
  152. torchzero/core/var.py +0 -376
  153. torchzero/modules/experimental/momentum.py +0 -160
  154. torchzero/optim/wrappers/scipy.py +0 -572
  155. torchzero/utils/linalg/__init__.py +0 -12
  156. torchzero/utils/linalg/matrix_funcs.py +0 -87
  157. torchzero/utils/linalg/orthogonalize.py +0 -12
  158. torchzero/utils/linalg/svd.py +0 -20
  159. torchzero/utils/ops.py +0 -10
  160. torchzero-0.3.15.dist-info/RECORD +0 -175
  161. /torchzero/{utils/linalg → linalg}/benchmark.py +0 -0
  162. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/WHEEL +0 -0
  163. {torchzero-0.3.15.dist-info → torchzero-0.4.0.dist-info}/top_level.txt +0 -0
torchzero/modules/adaptive/adahessian.py

@@ -3,9 +3,8 @@ from typing import Literal
 
  import torch
 
- from ...core import Chainable, Module, Target, Transform, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist
- from ..functional import debiased_step_size
+ from ...core import Chainable, Transform, HVPMethod
+ from ...utils import NumberList, TensorList, Distributions, unpack_dicts, unpack_states
 
  def _full_average(hvp: torch.Tensor):
      if hvp.ndim >= 3: # Conv kernel
@@ -37,41 +36,7 @@ def _block_average(x: torch.Tensor, block_size: int | None, enable: bool):
      return x
 
 
- def _rademacher_like(tensor, p = 0.5, generator = None):
-     """p is probability of a 1, other values will be -1."""
-     return torch.bernoulli(torch.full_like(tensor, p), generator = generator).mul_(2).sub_(1)
-
- def adahessian(
-     tensors: TensorList,
-     D: TensorList | None,
-     exp_avg_: TensorList,
-     D_exp_avg_sq_: TensorList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     update_freq: int,
-     eps: float | NumberList,
-     hessian_power: float | NumberList,
-     step: int,
- ):
-     # momentum
-     exp_avg_.lerp_(tensors, 1-beta1)
-
-     # update preconditioner
-     if step % update_freq == 0:
-         assert D is not None
-         D_exp_avg_sq_.mul_(beta2).addcmul_(D, D, 1-beta2)
-
-     else:
-         assert D is None
-
-
-     denom = D_exp_avg_sq_.sqrt().pow_(hessian_power).add_(eps)
-     num = exp_avg_ * debiased_step_size(step+1, beta1, beta2)
-
-     return num.div_(denom)
-
-
- class AdaHessian(Module):
+ class AdaHessian(Transform):
      """AdaHessian: An Adaptive Second Order Optimizer for Machine Learning (https://arxiv.org/abs/2006.00719)
 
      This is similar to Adam, but the second momentum is replaced by square root of an exponential moving average of random hessian-vector products.
@@ -79,8 +44,6 @@ class AdaHessian(Module):
      Notes:
          - In most cases AdaHessian should be the first module in the chain because it relies on autograd. Use the ``inner`` argument if you wish to apply AdaHessian preconditioning to another module's output.
 
-         - If you are using gradient estimators or reformulations, set ``hvp_method`` to "forward" or "central".
-
          - This module requires a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating HVPs. The closure must accept a ``backward`` argument (refer to documentation).
 
      Args:
@@ -97,17 +60,17 @@ class AdaHessian(Module):
          eps (float, optional):
              division stability epsilon. Defaults to 1e-8.
          hvp_method (str, optional):
-             Determines how Hessian-vector products are evaluated.
-
-             - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
-               This requires creating a graph for the gradient.
-             - ``"forward"``: Use a forward finite difference formula to
-               approximate the HVP. This requires one extra gradient evaluation.
-             - ``"central"``: Use a central finite difference formula for a
-               more accurate HVP approximation. This requires two extra
-               gradient evaluations.
-             Defaults to "autograd".
-         fd_h (float, optional): finite difference step size if ``hvp_method`` is "forward" or "central". Defaults to 1e-3.
+             Determines how hessian-vector products are computed.
+
+             - ``"batched_autograd"`` - uses autograd with batched hessian-vector products. If a single hessian-vector is evaluated, equivalent to ``"autograd"``. Faster than ``"autograd"`` but uses more memory.
+             - ``"autograd"`` - uses autograd hessian-vector products. If multiple hessian-vector products are evaluated, uses a for-loop. Slower than ``"batched_autograd"`` but uses less memory.
+             - ``"fd_forward"`` - uses gradient finite difference approximation with a less accurate forward formula which requires one extra gradient evaluation per hessian-vector product.
+             - ``"fd_central"`` - uses gradient finite difference approximation with a more accurate central formula which requires two gradient evaluations per hessian-vector product.
+
+             Defaults to ``"autograd"``.
+         h (float, optional):
+             The step size for finite difference if ``hvp_method`` is
+             ``"fd_forward"`` or ``"fd_central"``. Defaults to 1e-3.
          n_samples (int, optional):
              number of hessian-vector products with random vectors to evaluate each time when updating
              the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -151,74 +114,82 @@ class AdaHessian(Module):
          update_freq: int = 1,
          eps: float = 1e-8,
          hessian_power: float = 1,
-         hvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
-         fd_h: float = 1e-3,
+         distribution: Distributions = 'rademacher',
+         hvp_method: HVPMethod = 'autograd',
+         h: float = 1e-3,
          n_samples = 1,
+         zHz: bool = True,
+         debias: bool = True,
          seed: int | None = None,
-         inner: Chainable | None = None
+
+         exp_avg_tfm: Chainable | None = None,
+         D_exp_avg_sq_tfm: Chainable | None = None,
      ):
-         defaults = dict(beta1=beta1, beta2=beta2, update_freq=update_freq, averaging=averaging, block_size=block_size, eps=eps, hessian_power=hessian_power, hvp_method=hvp_method, n_samples=n_samples, fd_h=fd_h, seed=seed)
+         defaults = locals().copy()
+         del defaults['self'], defaults["exp_avg_tfm"], defaults["D_exp_avg_sq_tfm"]
          super().__init__(defaults)
 
-         if inner is not None:
-             self.set_child('inner', inner)
+         self.set_child('exp_avg', exp_avg_tfm)
+         self.set_child('D_exp_avg_sq', D_exp_avg_sq_tfm)
 
      @torch.no_grad
-     def step(self, var):
-         params = var.params
-         settings = self.settings[params[0]]
-         hvp_method = settings['hvp_method']
-         fd_h = settings['fd_h']
-         update_freq = settings['update_freq']
-         n_samples = settings['n_samples']
+     def update_states(self, objective, states, settings):
+         params = objective.params
 
-         seed = settings['seed']
-         generator = self.get_generator(params[0].device, seed)
+         beta1, beta2, averaging, block_size = unpack_dicts(settings, 'beta1', 'beta2', 'averaging', 'block_size', cls=NumberList)
 
-         beta1, beta2, eps, averaging, block_size, hessian_power = self.get_settings(params,
-             'beta1', 'beta2', 'eps', 'averaging', 'block_size', "hessian_power", cls=NumberList)
+         exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)
 
-         exp_avg, D_exp_avg_sq = self.get_state(params, 'exp_avg', 'h_exp_avg', cls=TensorList)
+         # ---------------------------- hutchinson hessian ---------------------------- #
+         fs = settings[0]
+         step = self.increment_counter("step", start=0) # 0 on 1st update
+         update_freq = fs['update_freq']
 
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
+         if step % update_freq == 0:
+             self.increment_counter("num_Ds", start=1)
+
+             D, _ = objective.hutchinson_hessian(
+                 rgrad = None,
+                 at_x0 = True,
+                 n_samples = fs['n_samples'],
+                 distribution = fs['distribution'],
+                 hvp_method = fs['hvp_method'],
+                 h = fs['h'],
+                 zHz = fs["zHz"],
+                 generator = self.get_generator(params[0].device, fs["seed"]),
+             )
 
-         closure = var.closure
-         assert closure is not None
+             D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
+             D_exp_avg_sq.mul_(beta2).addcmul_(D, D, value=1-beta2)
 
-         D = None
-         if step % update_freq == 0:
+         # --------------------------------- momentum --------------------------------- #
+         tensors = objective.get_updates() # do this after hutchinson to not disturb autograd
+         exp_avg.lerp_(tensors, 1-beta1)
 
-             rgrad=None
-             for i in range(n_samples):
-                 u = [_rademacher_like(p, generator=generator) for p in params]
 
-                 Hvp, rgrad = var.hessian_vector_product(u, at_x0=True, rgrad=rgrad, hvp_method=hvp_method,
-                     h=fd_h, normalize=True, retain_graph=i < n_samples-1)
-                 Hvp = tuple(Hvp)
+     @torch.no_grad
+     def apply_states(self, objective, states, settings):
+         params = objective.params
 
-                 if D is None: D = Hvp
-                 else: torch._foreach_add_(D, Hvp)
+         beta1, beta2, eps, hessian_power = unpack_dicts(settings, 'beta1', 'beta2', 'eps', 'hessian_power', cls=NumberList)
+         exp_avg, D_exp_avg_sq = unpack_states(states, params, 'exp_avg', 'D_exp_avg_sq', cls=TensorList)
 
-             assert D is not None
-             if n_samples > 1: torch._foreach_div_(D, n_samples)
+         # ---------------------------------- debias ---------------------------------- #
+         if settings[0]["debias"]:
+             bias_correction1 = 1.0 - (beta1 ** (self.global_state["step"] + 1))
+             bias_correction2 = 1.0 - (beta2 ** self.global_state["num_Ds"])
+             exp_avg = exp_avg / bias_correction1
+             D_exp_avg_sq = D_exp_avg_sq / bias_correction2
 
-             D = TensorList(D).zipmap_args(_block_average, block_size, averaging)
 
-         update = var.get_update()
-         if 'inner' in self.children:
-             update = apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var)
-
-         var.update = adahessian(
-             tensors=TensorList(update),
-             D=TensorList(D) if D is not None else None,
-             exp_avg_=exp_avg,
-             D_exp_avg_sq_=D_exp_avg_sq,
-             beta1=beta1,
-             beta2=beta2,
-             update_freq=update_freq,
-             eps=eps,
-             hessian_power=hessian_power,
-             step=step,
-         )
-         return var
+         # -------------------------------- transforms -------------------------------- #
+         exp_avg = TensorList(self.inner_step_tensors(
+             "exp_avg", tensors=exp_avg, clone=True, objective=objective, must_exist=False))
+
+         D_exp_avg_sq = TensorList(self.inner_step_tensors(
+             "D_exp_avg_sq", tensors=D_exp_avg_sq, clone=True, objective=objective, must_exist=False))
+
+         # ------------------------------ compute update ------------------------------ #
+         denom = D_exp_avg_sq.lazy_pow(hessian_power / 2) + eps
+         objective.updates = exp_avg / denom
+         return objective
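The reworked AdaHessian still requires a closure that accepts a ``backward`` argument, as its docstring notes. A minimal usage sketch, assuming the ``tz.Modular`` chain and ``tz.m.AdaHessian`` / ``tz.m.LR`` names follow the pattern of the Adan docstring example below; the training snippet itself is illustrative, not taken from this diff:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# AdaHessian first in the chain (it relies on autograd), LR scales its output.
opt = tz.Modular(model.parameters(), tz.m.AdaHessian(update_freq=10), tz.m.LR(0.1))

def closure(backward=True):
    # the module re-evaluates this closure when building hessian-vector products
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):
    opt.step(closure)
```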
torchzero/modules/adaptive/adam.py

@@ -1,48 +1,11 @@
- from operator import itemgetter
- from functools import partial
-
  import torch
 
- from ...core import Module, Target, Transform, apply_transform, Chainable
+ from ...core import Chainable, Module, TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from ..functional import (
-     debias, debiased_step_size,
-     ema_,
-     sqrt_ema_sq_,
- )
-
-
- def adam_(
-     tensors: TensorList,
-     exp_avg_: TensorList,
-     exp_avg_sq_: TensorList,
-     alpha: float | NumberList,
-     beta1: float | NumberList,
-     beta2: float | NumberList,
-     eps: float | NumberList,
-     step: int,
-     pow: float = 2,
-     debiased: bool = True,
-     max_exp_avg_sq_: TensorList | None = None,
-
-     # inner args
-     inner: Module | None = None,
-     params: list[torch.Tensor] | None = None,
-     grads: list[torch.Tensor] | None = None,
- ):
-     """Returns new tensors."""
-     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
-         debiased=False,step=step,pow=pow)
-
-     if inner is not None:
-         assert params is not None
-         tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
-
-     exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-     if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
-     return (exp_avg_.lazy_mul(alpha) / sqrt_exp_avg_sq.add_(eps))
-
- class Adam(Transform):
+ from ..functional import debiased_step_size
+
+
+ class Adam(TensorTransform):
      """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size.
 
      This implementation is identical to :code:`torch.optim.Adam`.
@@ -54,7 +17,7 @@ class Adam(Transform):
          alpha (float, optional): learning rate. Defaults to 1.
          amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
          pow (float, optional): power used in second momentum power and root. Defaults to 2.
-         debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+         debias (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
      """
      def __init__(
          self,
@@ -63,45 +26,59 @@ class Adam(Transform):
          eps: float = 1e-8,
          amsgrad: bool = False,
          alpha: float = 1.,
-         pow: float = 2,
-         debiased: bool = True,
-         inner: Chainable | None = None
+         debias: bool = True,
+
+         exp_avg_tfm: Chainable | None = None,
+         exp_avg_sq_tfm: Chainable | None = None,
      ):
-         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-         super().__init__(defaults, uses_grad=False)
+         defaults = locals().copy()
+         del defaults['self'], defaults["exp_avg_tfm"], defaults["exp_avg_sq_tfm"]
+         super().__init__(defaults)
 
-         if inner is not None: self.set_child('inner', inner)
+         self.set_child('exp_avg', exp_avg_tfm)
+         self.set_child('exp_avg_sq', exp_avg_sq_tfm)
 
      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
-         amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
-
-         if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
+     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
+         self.increment_counter("step", start=0)
+         beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+
+         # ----------------------------- initialize states ---------------------------- #
+         if settings[0]["amsgrad"]:
+             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(
+                 states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
+         # ------------------------------ update moments ------------------------------ #
+         exp_avg.lerp_(tensors, weight=1-beta1)
+         exp_avg_sq.mul_(beta2).addcmul_(tensors, tensors, value=1-beta2)
+
+         if max_exp_avg_sq is not None:
+             max_exp_avg_sq.maximum_(exp_avg_sq)
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state["step"] # 0 on 1st step
+         fs = settings[0]
+
+         if fs["amsgrad"]: key = "max_exp_avg_sq"
+         else: key = "exp_avg_sq"
+         exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', key, cls=TensorList)
+         beta1, beta2, alpha, eps = unpack_dicts(settings, 'beta1', 'beta2', 'alpha', 'eps', cls=NumberList)
+
+         # -------------------------------- transforms -------------------------------- #
+         exp_avg = TensorList(self.inner_step_tensors(
+             "exp_avg", tensors=exp_avg, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+         exp_avg_sq = TensorList(self.inner_step_tensors(
+             "exp_avg_sq", tensors=exp_avg_sq, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+         # ---------------------------------- debias ---------------------------------- #
+         if fs["debias"]:
+             alpha = debiased_step_size((step + 1), beta1=beta1, beta2=beta2, alpha=alpha)
+             exp_avg = exp_avg * alpha
 
-         return adam_(
-             tensors=TensorList(tensors),
-             exp_avg_=exp_avg,
-             exp_avg_sq_=exp_avg_sq,
-             alpha=alpha,
-             beta1=beta1,
-             beta2=beta2,
-             eps=eps,
-             step=step,
-             pow=pow,
-             debiased=debiased,
-             max_exp_avg_sq_=max_exp_avg_sq,
-
-             # inner args
-             inner=self.children.get("inner", None),
-             params=params,
-             grads=grads,
-
-         )
+         # ---------------------------------- update ---------------------------------- #
+         return exp_avg / exp_avg_sq.sqrt().add_(eps)
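Since Adam is a plain gradient transform (no HVPs), no closure is needed. A hedged sketch along the lines of the Adan docstring example; the ``tz.Modular`` / ``tz.m.Adam`` / ``tz.m.LR`` usage is assumed from that example rather than stated in this diff:

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
# Adam produces the direction, LR scales it (alpha defaults to 1).
opt = tz.Modular(model.parameters(), tz.m.Adam(beta1=0.9, beta2=0.999, amsgrad=False), tz.m.LR(1e-3))

for _ in range(100):
    X, y = torch.randn(32, 10), torch.randn(32, 1)
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(X), y)
    loss.backward()
    opt.step()
```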
torchzero/modules/adaptive/adan.py

@@ -1,9 +1,9 @@
  import torch
 
- from ...core import Transform
+ from ...core import TensorTransform, Chainable
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
- def adan_(
+ def adan_update_(
      g: TensorList,
      g_prev_: TensorList,
      m_: TensorList, # exponential moving average
@@ -12,10 +12,8 @@ def adan_(
      beta1: float | NumberList,
      beta2: float | NumberList,
      beta3: float | NumberList,
-     eps: float | NumberList,
      step: int,
  ):
-     """Returns new tensors"""
      m_.lerp_(g, 1 - beta1)
 
      if step == 1:
@@ -26,7 +24,18 @@ def adan_(
      term = g + beta2 * diff
 
      n_.mul_(beta3).addcmul_(term, term, value=(1 - beta3))
+     g_prev_.copy_(g)
 
+ def adan_apply_(
+     m_: TensorList, # exponential moving average
+     v_: TensorList, # exponential moving average of gradient differences
+     n_: TensorList, # kinda like squared momentum
+     beta1: float | NumberList,
+     beta2: float | NumberList,
+     beta3: float | NumberList,
+     eps: float | NumberList,
+     step: int,
+ ):
      m = m_ / (1.0 - beta1**step)
      v = v_ / (1.0 - beta2**step)
      n = n_ / (1.0 - beta3**step)
@@ -35,13 +44,12 @@ def adan_(
      num = m + beta2 * v
 
      update = num.div_(denom)
-     g_prev_.copy_(g)
 
      return update
 
 
 
- class Adan(Transform):
+ class Adan(TensorTransform):
      """Adaptive Nesterov Momentum Algorithm from https://arxiv.org/abs/2208.06677
 
      Args:
@@ -49,8 +57,6 @@ class Adan(Transform):
          beta2 (float, optional): momentum for gradient differences. Defaults to 0.92.
          beta3 (float, optional): thrid (squared) momentum. Defaults to 0.99.
          eps (float, optional): epsilon. Defaults to 1e-8.
-         use_n_prev (bool, optional):
-             whether to use previous gradient differences momentum.
 
      Example:
         ```python
@@ -59,8 +65,9 @@ class Adan(Transform):
             tz.m.Adan(),
             tz.m.LR(1e-3),
         )
+        ```
      Reference:
-         Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence. https://arxiv.org/abs/2208.06677
+         [Xie, X., Zhou, P., Li, H., Lin, Z., & Yan, S. (2024). Adan: Adaptive nesterov momentum algorithm for faster optimizing deep models. IEEE Transactions on Pattern Analysis and Machine Intelligence](https://arxiv.org/abs/2208.06677).
      """
      def __init__(
          self,
@@ -68,29 +75,41 @@ class Adan(Transform):
          beta2: float = 0.92,
          beta3: float = 0.99,
          eps: float = 1e-8,
+
+         m_tfm: Chainable | None = None,
+         v_tfm: Chainable | None = None,
+         n_tfm: Chainable | None = None,
      ):
-         defaults=dict(beta1=beta1,beta2=beta2,beta3=beta3,eps=eps)
+         defaults=dict(beta1=beta1, beta2=beta2, beta3=beta3, eps=eps)
          super().__init__(defaults, uses_grad=False)
 
+         self.set_child("m", m_tfm)
+         self.set_child("v", v_tfm)
+         self.set_child("n", n_tfm)
+
      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     def multi_tensor_update(self, tensors, params, grads, loss, states, settings):
          tensors = TensorList(tensors)
-         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
-         beta1,beta2,beta3,eps=unpack_dicts(settings, 'beta1','beta2','beta3','eps', cls=NumberList)
-         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev','m','v','n', cls=TensorList)
-
-         update = adan_(
-             g=tensors,
-             g_prev_=g_prev,
-             m_=m,
-             v_=v,
-             n_=n,
-             beta1=beta1,
-             beta2=beta2,
-             beta3=beta3,
-             eps=eps,
-             step=step,
-         )
-
-         return update
+         step = self.increment_counter("step", start=0)
+
+         beta1, beta2, beta3 = unpack_dicts(settings, 'beta1','beta2','beta3', cls=NumberList)
+         g_prev, m, v, n = unpack_states(states, tensors, 'g_prev', 'm', 'v', 'n', cls=TensorList)
+
+         adan_update_(g=tensors, g_prev_=g_prev, m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, step=step+1)
+
+     @torch.no_grad
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
+         tensors = TensorList(tensors)
+         step = self.global_state["step"] # 0 on 1st step
+
+         beta1, beta2, beta3, eps = unpack_dicts(settings, 'beta1','beta2','beta3', 'eps', cls=NumberList)
+         m, v, n = unpack_states(states, tensors, 'm', 'v', 'n')
+
+         # -------------------------------- transforms -------------------------------- #
+         m = TensorList(self.inner_step_tensors("m", m, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+         v = TensorList(self.inner_step_tensors("v", v, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+         n = TensorList(self.inner_step_tensors("n", n, clone=True, params=params, grads=grads, loss=loss, must_exist=False))
+
+         # ---------------------------------- update ---------------------------------- #
+         return adan_apply_(m_=m, v_=v, n_=n, beta1=beta1, beta2=beta2, beta3=beta3, eps=eps, step=step+1)
+
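The refactor splits the old ``adan_`` helper into ``adan_update_`` (mutates the state buffers in place) and ``adan_apply_`` (reads them and returns the step). A rough sketch of that contract on plain tensors, assuming the helpers can be called standalone and that ``TensorList`` arithmetic behaves elementwise; the betas are arbitrary illustrative values:

```python
import torch
from torchzero.utils import TensorList
from torchzero.modules.adaptive.adan import adan_update_, adan_apply_

# one parameter group, state buffers start at zero
g_prev, m, v, n = (TensorList([torch.zeros(3)]) for _ in range(4))

for step in range(1, 4):
    g = TensorList([torch.randn(3)])  # stand-in for the incoming update
    adan_update_(g=g, g_prev_=g_prev, m_=m, v_=v, n_=n,
                 beta1=0.98, beta2=0.92, beta3=0.99, step=step)  # state only, returns nothing
    direction = adan_apply_(m_=m, v_=v, n_=n,
                            beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, step=step)
```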
torchzero/modules/adaptive/adaptive_heavyball.py

@@ -1,5 +1,5 @@
  import torch
- from ...core import Transform
+ from ...core import TensorTransform
  from ...utils import TensorList, unpack_dicts, unpack_states
 
 
@@ -16,10 +16,10 @@ def adaptive_heavy_ball(f, f_star, f_prev, g: TensorList, g_prev: TensorList, p:
      return (1 + m) * h * g - m*(p-p_prev)
 
 
- class AdaptiveHeavyBall(Transform):
+ class AdaptiveHeavyBall(TensorTransform):
      """Adaptive heavy ball from https://hal.science/hal-04832983v1/file/OJMO_2024__5__A7_0.pdf.
 
-     This is related to conjugate gradient methods, it may be very good for non-stochastic convex objectives, but won't work on stochastic ones.
+     Suitable for quadratic objectives with known f* (loss at minimum).
 
      note:
          The step size is determined by the algorithm, so learning rate modules shouldn't be used.
@@ -33,22 +33,27 @@ class AdaptiveHeavyBall(Transform):
          super().__init__(defaults, uses_grad=False, uses_loss=True)
 
      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          assert loss is not None
         tensors = TensorList(tensors)
-         f_star = self.defaults['f_star']
+         f_star = settings[0]['f_star']
 
         f_prev = self.global_state.get('f_prev', None)
         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', init=[params,tensors], cls=TensorList)
 
+         # -------------------------------- first step -------------------------------- #
         if f_prev is None:
             self.global_state['f_prev'] = loss
             h = 2*(loss - f_star) / tensors.dot(tensors)
             return h * tensors
 
-         update = adaptive_heavy_ball(f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
+         # ------------------------------- further steps ------------------------------ #
+         update = adaptive_heavy_ball(
+             f=loss, f_star=f_star, f_prev=f_prev, g=tensors, g_prev=g_prev, p=TensorList(params), p_prev=p_prev)
 
+         # --------------------------- store previous values -------------------------- #
         self.global_state['f_prev'] = loss
         p_prev.copy_(params)
         g_prev.copy_(tensors)
+
         return update
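Because the updated docstring targets quadratics with a known minimum value and the step size is chosen by the method itself, no LR module belongs in the chain. A hedged sketch; the ``f_star`` keyword, ``tz.Modular``, and the closure convention are assumptions based on the docstring and defaults shown above:

```python
import torch
import torchzero as tz

x = torch.nn.Parameter(torch.randn(10))
opt = tz.Modular([x], tz.m.AdaptiveHeavyBall(f_star=0.0))  # no tz.m.LR: the method sets its own step size

def closure(backward=True):
    loss = (x ** 2).sum()  # quadratic with known minimum value f* = 0
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(25):
    opt.step(closure)  # uses_loss=True, so the loss value must reach the module
```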
torchzero/modules/adaptive/aegd.py

@@ -2,17 +2,18 @@ import math
 
  import torch
 
- from ...core import Transform
+ from ...core import TensorTransform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
  # i've verified, it is identical to official
  # https://github.com/txping/AEGD/blob/master/aegd.py
+ # TODO: add a test
  def aegd_(f: torch.Tensor | float, g: TensorList, r_: TensorList, c:float|NumberList=1, eta:float|NumberList=0.1) -> TensorList:
      v = g / (2 * (f + c)**0.5)
      r_ /= 1 + (v ** 2).mul_(2*eta) # update energy
      return 2*eta * r_*v # pyright:ignore[reportReturnType]
 
- class AEGD(Transform):
+ class AEGD(TensorTransform):
      """AEGD (Adaptive gradient descent with energy) from https://arxiv.org/abs/2010.05109#page=10.26.
 
      Note:
@@ -20,28 +21,27 @@ class AEGD(Transform):
          To avoid compounding learning rate mofications, remove the ``tz.m.LR`` module if you had it.
 
      Args:
-         eta (float, optional): step size. Defaults to 0.1.
-         c (float, optional): c. Defaults to 1.
-         beta3 (float, optional): thrid (squared) momentum. Defaults to 0.1.
-         eps (float, optional): epsilon. Defaults to 1e-8.
-         use_n_prev (bool, optional):
-             whether to use previous gradient differences momentum.
+         lr (float, optional): learning rate (default: 0.1)
+         c (float, optional): term added to the original objective function (default: 1)
+
+     Reference:
+         [Liu, Hailiang, and Xuping Tian. "AEGD: Adaptive gradient descent with energy." arXiv preprint arXiv:2010.05109 (2020).](https://arxiv.org/pdf/2010.05109)
      """
      def __init__(
          self,
          lr: float = 0.1,
          c: float = 1,
      ):
-         defaults=dict(c=c,lr=lr)
+         defaults = dict(c=c, lr=lr)
          super().__init__(defaults, uses_loss=True)
 
      @torch.no_grad
-     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+     def multi_tensor_apply(self, tensors, params, grads, loss, states, settings):
          assert loss is not None
         tensors = TensorList(tensors)
 
-         c,lr=unpack_dicts(settings, 'c','lr', cls=NumberList)
-         r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss+c[0])**0.5), cls=TensorList)
+         c, lr = unpack_dicts(settings, 'c', 'lr', cls=NumberList)
+         r = unpack_states(states, tensors, 'r', init=lambda t: torch.full_like(t, float(loss + c[0])**0.5), cls=TensorList)
 
         update = aegd_(
             f=loss,