torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/adagrad.py
@@ -1,49 +1,146 @@
-from collections import abc
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule
-
-def _adagrad_step_(ascent: TensorList, grad_sum: TensorList, alpha: TensorList, lr_decay: TensorList, eps: TensorList, step: int):
-    clr = alpha / (1 + step * lr_decay)
-    grad_sum.addcmul_(ascent, ascent)
-    return ascent.div_(grad_sum.sqrt().add_(eps)).mul_(clr)
-
-class Adagrad(OptimizerModule):
-    """
-    Divides ascent direction by mean square root of the sum of all past ascent directions.
-
-    Exactly matches `torch.optim.Adagrad`.
-
-    Args:
-        lr_decay (float, optional): learning rate decay. Defaults to 0.
-        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-        eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-10.
-        alpha (float, optional): learning rate. Defaults to 1.
-
-    reference
-        https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
-    """
-    def __init__(self, lr_decay: float = 0, initial_accumulator_value: float = 0, eps: float = 1e-10, alpha: float = 1):
-        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value, eps = eps)
-        super().__init__(defaults)
-        self.cur_step = 0
-
-    @torch.no_grad
-    def _update(self, vars, ascent):
-        settings = self.get_all_group_keys()
-        if self.cur_step == 0: init = ascent.full_like(settings['initial_accumulator_value'])
-        else: init = None
-        grad_sum = self.get_state_key('grad_sum', init = init) # type:ignore
-
-        updated_direction = _adagrad_step_(
-            ascent=ascent,
-            grad_sum=grad_sum,
-            alpha=settings["alpha"],
-            eps=settings["eps"],
-            lr_decay=settings["lr_decay"],
-            step=self.cur_step,
-        )
-        self.cur_step += 1
-        return updated_direction
+from operator import itemgetter
+
+import torch
+
+from ...core import (
+    Chainable,
+    Module,
+    Preconditioner,
+    Target,
+    TensorwisePreconditioner,
+    Transform,
+    Vars,
+    apply,
+)
+from ...utils import NumberList, TensorList
+from ...utils.linalg import matrix_power_eigh
+from ..functional import add_power_, lerp_power_, root
+
+
+def adagrad_(
+    tensors_: TensorList,
+    sq_sum_: TensorList,
+    alpha: float | NumberList,
+    lr_decay: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    use_sqrt: bool = True,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
+    vars: Vars | None = None,
+):
+    """returns `tensors_`"""
+    clr = alpha / (1 + step * lr_decay)
+
+    sq_sum_ = add_power_(tensors_, sum_=sq_sum_, pow=pow)
+
+    if inner is not None:
+        assert params is not None
+        tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+
+    if use_sqrt: tensors_.div_(root(sq_sum_, p=pow, inplace=False).add_(eps)).mul_(clr)
+    else: tensors_.div_(sq_sum_.add(eps)).mul_(clr)
+
+    return tensors_
+
+
+
+class Adagrad(Transform):
+    """Adagrad, divides by sum of past squares of gradients, matches pytorch Adagrad.
+
+    Args:
+        lr_decay (float, optional): learning rate decay. Defaults to 0.
+        initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
+        eps (float, optional): division epsilon. Defaults to 1e-10.
+        alpha (float, optional): step size. Defaults to 1.
+        pow (float, optional): power for gradients and accumulator root. Defaults to 2.
+        use_sqrt (bool, optional): whether to take the root of the accumulator. Defaults to True.
+        inner (Chainable | None, optional): Inner modules that are applied after updating accumulator and before preconditioning. Defaults to None.
+    """
+    def __init__(
+        self,
+        lr_decay: float = 0,
+        initial_accumulator_value: float = 0,
+        eps: float = 1e-10,
+        alpha: float = 1,
+        pow: float = 2,
+        use_sqrt: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(alpha = alpha, lr_decay = lr_decay, initial_accumulator_value=initial_accumulator_value,
+                        eps = eps, pow=pow, use_sqrt = use_sqrt)
+        super().__init__(defaults=defaults, uses_grad=False)
+
+        if inner is not None:
+            self.set_child('inner', inner)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        tensors = TensorList(tensors)
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        lr_decay, alpha, eps = self.get_settings('lr_decay', 'alpha', 'eps', params=params, cls=NumberList)
+
+        pow, use_sqrt = itemgetter('pow', 'use_sqrt')(self.settings[params[0]])
+
+        sq_sum = self.get_state('sq_sum', params=params, cls=TensorList)
+
+        # initialize accumulator on 1st step
+        if step == 1:
+            sq_sum.set_(tensors.full_like(self.get_settings('initial_accumulator_value', params=params)))
+
+        return adagrad_(
+            tensors,
+            sq_sum_=sq_sum,
+            alpha=alpha,
+            lr_decay=lr_decay,
+            eps=eps,
+            step=self.global_state["step"],
+            pow=pow,
+            use_sqrt=use_sqrt,
+
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+            vars=vars,
+        )
+
+
+
+class FullMatrixAdagrad(TensorwisePreconditioner):
+    def __init__(self, beta: float | None = None, decay: float | None = None, concat_params=False, update_freq=1, inner: Chainable | None = None):
+        defaults = dict(beta=beta, decay=decay)
+        super().__init__(defaults, uses_grad=False, concat_params=concat_params, update_freq=update_freq, inner=inner)
+
+    @torch.no_grad
+    def update_tensor(self, tensor, param, grad, state, settings):
+        G = tensor.ravel()
+        GG = torch.outer(G, G)
+        decay = settings['decay']
+        beta = settings['beta']
+
+        if 'GG' not in state: state['GG'] = torch.eye(GG.size(0), device=GG.device, dtype=GG.dtype)
+        if decay is not None: state['GG'].mul_(decay)
+
+        if beta is not None: state['GG'].lerp_(GG, 1-beta)
+        else: state['GG'].add_(GG)
+
+    @torch.no_grad
+    def apply_tensor(self, tensor, param, grad, state, settings):
+        GG = state['GG']
+
+        if tensor.numel() == 1:
+            return tensor / (GG**(1/2)).squeeze()
+
+        try:
+            B = matrix_power_eigh(GG, -1/2)
+        except torch.linalg.LinAlgError:
+            return tensor.div_(tensor.abs().max()) # conservative scaling
+
+        return (B @ tensor.ravel()).view_as(tensor)
+
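The rewrite above splits Adagrad into a functional core (`adagrad_`) plus a `Transform` wrapper, and adds a full-matrix variant that accumulates outer products of the gradient and preconditions with the inverse matrix square root of that accumulator. As a quick reference, here is a minimal plain-PyTorch sketch of the two update rules; it deliberately avoids torchzero's `TensorList`/`Transform` machinery, and the helper names `adagrad_update` and `full_matrix_adagrad_update` are illustrative, not part of the package API.

```python
import torch

def adagrad_update(g, sq_sum, alpha=1.0, lr_decay=0.0, eps=1e-10, step=1):
    """Diagonal Adagrad: accumulate g**2, divide by its root, decay the step size."""
    clr = alpha / (1 + step * lr_decay)
    sq_sum.add_(g * g)
    return g / (sq_sum.sqrt() + eps) * clr

def full_matrix_adagrad_update(g, GG):
    """Full-matrix Adagrad: accumulate outer products, precondition with GG^(-1/2)."""
    v = g.ravel()
    GG.add_(torch.outer(v, v))
    # inverse matrix square root via eigendecomposition
    # (the diff uses torchzero's matrix_power_eigh helper for this)
    L, Q = torch.linalg.eigh(GG)
    B = Q @ torch.diag(L.clamp(min=1e-12).rsqrt()) @ Q.T
    return (B @ v).view_as(g)

# illustrative usage on a single parameter tensor
p = torch.randn(4, requires_grad=True)
sq_sum = torch.zeros_like(p)
(p ** 2).sum().backward()
with torch.no_grad():
    p -= 0.1 * adagrad_update(p.grad, sq_sum)
```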
torchzero/modules/optimizers/adam.py
@@ -1,118 +1,112 @@
-from collections import abc
-
-import torch
-
-from ...tensorlist import TensorList
-from ...core import OptimizerModule
-
-def _adam_step(ascent: TensorList, exp_avg: TensorList, exp_avg_sq: TensorList, alpha, beta1, beta2, eps, step:int, max_exp_avg_sqs: TensorList | None, params: TensorList | None = None):
-    # Decay the first and second moment running average coefficient
-    exp_avg.lerp_compat_(ascent, 1 - beta1)
-    exp_avg_sq.mul_(beta2).addcmul_(ascent, ascent.conj(), value=1 - beta2)
-
-    bias_correction1 = 1 - beta1**step
-    bias_correction2 = 1 - beta2**step
-
-    if max_exp_avg_sqs is not None:
-        max_exp_avg_sqs.maximum_(exp_avg_sq)
-        denom = max_exp_avg_sqs.sqrt().div_(bias_correction2**0.5).add_(eps)
-    else:
-        denom = exp_avg_sq.sqrt().div_(bias_correction2**0.5).add_(eps)
-
-    if params is None:
-        return (exp_avg / denom).mul_(alpha / bias_correction1)
-
-    # else directly apply the update to params
-    params.addcdiv_(exp_avg, denom, value = -(alpha / bias_correction1))
-    return None
-
-
-
-class Adam(OptimizerModule):
-    """Adam. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`.
-
-    Args:
-        beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-        beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-        eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-        amsgrad (bool, optional):
-            whether to use the AMSGrad variant of this algorithm from
-            On the Convergence of Adam and Beyond (default: False).
-        alpha (float, optional): learning rate. Defaults to 1.
-    """
-    def __init__(self, beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8, alpha: float = 1, amsgrad=False):
-        defaults = dict(alpha = alpha, beta1=beta1, beta2=beta2, eps=eps)
-        super().__init__(defaults)
-
-        self.cur_step = 1
-        self.amsgrad = amsgrad
-
-    @torch.no_grad
-    def step(self, vars):
-        # Adam step is a bit differet from other optimizer steps
-        # due to how common it is, I implemented two additional optimizations,
-
-        # 1st - if next module is None or if next module is LR and module after is None
-        # this will directly update parameters using `addcdiv_`
-
-        # 2nd - if next module is LR`, adam will "fuse" with it to avoid an additional add operation.
-
-        # the optimizations are quite verbose and seem to barely have any effect, so I probably won't implement
-        # this for other modules
-
-        settings = self.get_all_group_keys()
-
-        if self.amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sqs = self.get_state_keys('exp_avg', 'exp_avg_sq', 'max_exp_avg_sqs')
-        else:
-            exp_avg, exp_avg_sq = self.get_state_keys('exp_avg', 'exp_avg_sq')
-            max_exp_avg_sqs = None
-
-        params = None
-
-        # apply addcdiv if next module is None
-        if self.next_module is None: params = self.get_params()
-
-        # fuse with LR module if it is next
-        if self.next_module is not None and self.next_module.IS_LR_MODULE:
-            alpha = self.next_module.get_group_key('lr') * settings['alpha']
-            self.next_module._skip = True # type:ignore
-
-            # apply addcdiv if module after LR is None.
-            if self.next_module.next_module is None: params = self.get_params()
-
-        else:
-            alpha = settings['alpha']
-
-        # get params if ascent is None so we need params to access their gradient as initial ascent
-        if vars.ascent is None:
-            if params is None: pg = self.get_params()
-            else: pg = params
-        else:
-            pg = None
-
-        ret = _adam_step(
-            ascent=vars.maybe_use_grad_(pg),
-            exp_avg = exp_avg,
-            exp_avg_sq = exp_avg_sq,
-            alpha = alpha,
-            beta1 = settings['beta1'],
-            beta2 = settings['beta2'],
-            eps = settings['eps'],
-            step = self.cur_step,
-            max_exp_avg_sqs = max_exp_avg_sqs,
-            params = params
-        )
-
-        self.cur_step += 1
-        if params is None:
-            assert ret is not None
-            vars.ascent = ret
-            return self._update_params_or_step_with_next(vars)
-
-        # next module is either None or LR
-        if self.next_module is None: return vars.get_loss()
-
-        # step with LR, which has _skip = True so it won't apply lr, but may step with the scheduler
-        self.next_module._update(vars, None) # type:ignore
-        return vars.get_loss()
+from operator import itemgetter
+from functools import partial
+
+import torch
+
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import (
+    debias, debiased_step_size,
+    ema_,
+    sqrt_ema_sq_,
+)
+from ..lr.lr import lazy_lr
+from ..momentum.experimental import sqrt_nag_ema_sq_
+from ..momentum.momentum import nag_
+
+
+def adam_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    alpha: float | NumberList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    eps: float | NumberList,
+    step: int,
+    pow: float = 2,
+    debiased: bool = True,
+    max_exp_avg_sq_: TensorList | None = None,
+    params_: TensorList | None = None,
+):
+    """Returns new tensors or updates params in-place."""
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=True)
+
+    sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
+                                   debiased=False, step=step, pow=pow)
+
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+
+    # params is None, return update
+    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
+
+    # update params in-place
+    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
+    return None
+
+class Adam(Module):
+    """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
+    pytorch in that debiasing is applied after adding epsilon.
+
+    Args:
+        beta1 (float, optional): momentum. Defaults to 0.9.
+        beta2 (float, optional): second momentum. Defaults to 0.999.
+        eps (float, optional): epsilon. Defaults to 1e-8.
+        alpha (float, optional): learning rate. Defaults to 1.
+        amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+        pow (float, optional): power used in second momentum power and root. Defaults to 2.
+        debiased (bool, optional): whether to apply debiasing to momentums based on current step. Defaults to True.
+    """
+    def __init__(
+        self,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        eps: float = 1e-8,
+        amsgrad: bool = False,
+        alpha: float = 1.,
+        pow: float = 2,
+        debiased: bool = True,
+    ):
+        defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
+        super().__init__(defaults)
+        self.getter = itemgetter('amsgrad','pow','debiased')
+
+    @torch.no_grad
+    def step(self, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
+        amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
+        if vars.is_last:
+            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
+            passed_params = TensorList(vars.params)
+            vars.stop = True
+            vars.skip_update = True
+
+        else:
+            passed_params = None
+
+        vars.update = adam_(
+            tensors=TensorList(vars.get_update()),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            alpha=alpha,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            step=step,
+            pow=pow,
+            debiased=debiased,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            params_=passed_params,
+        )
+
+        return vars
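The new Adam docstring calls out one deliberate difference from `torch.optim.Adam`: epsilon is added to the raw root of the second-moment EMA, and both bias corrections are then folded into the step size. A single-tensor sketch of that ordering follows; the name `adam_update` is illustrative, and `pow=2` with `debiased=True` is assumed.

```python
import torch

def adam_update(g, exp_avg, exp_avg_sq, step,
                alpha=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    # first and second moment EMAs, updated in place
    exp_avg.lerp_(g, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)

    # epsilon is added to the *un-debiased* root ...
    denom = exp_avg_sq.sqrt().add_(eps)

    # ... and both bias corrections are folded into the step size
    step_size = alpha * (1 - beta2 ** step) ** 0.5 / (1 - beta1 ** step)
    return exp_avg / denom * step_size

# torch.optim.Adam instead uses sqrt(exp_avg_sq / (1 - beta2**step)) + eps in the
# denominator, so the two updates differ only in how eps is scaled.
```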
torchzero/modules/optimizers/lion.py
@@ -1,15 +1,21 @@
 import torch
 
-from ...core import OptimizerModule
-from ...tensorlist import TensorList
+from ...core import Module, Target, Transform
+from ...utils import NumberList, TensorList
 
 
-def _lion_step_(ascent: TensorList, ema: TensorList, beta1, beta2,):
-    update = ema.lerp_compat(ascent, 1-beta1).sign_()
-    ema.lerp_compat_(ascent, 1-beta2)
+def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
+    """
+    Lion update rule.
+
+    Returns new tensors.
+    """
+    update = exp_avg_.lerp(tensors, 1-beta1).sign_()
+    exp_avg_.lerp_(tensors, 1-beta2)
     return update
 
-class Lion(OptimizerModule):
+
+class Lion(Transform):
     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
 
     Args:
@@ -19,10 +25,11 @@ class Lion(OptimizerModule):
 
     def __init__(self, beta1: float = 0.9, beta2: float = 0.99):
         defaults = dict(beta1=beta1, beta2=beta2)
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def _update(self, vars, ascent):
-        beta1, beta2 = self.get_group_keys('beta1', 'beta2')
-        ema = self.get_state_key('ema')
-        return _lion_step_(ascent,ema,beta1,beta2)
+    def transform(self, tensors, params, grads, vars):
+        beta1, beta2 = self.get_settings('beta1', 'beta2', params = params, cls=NumberList)
+        exp_avg = self.get_state('ema', params=params, cls=TensorList)
+        return lion_(TensorList(tensors),exp_avg,beta1,beta2)
+
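The `lion_` helper keeps the Lion rule to two lines: the update direction is the sign of a beta1-interpolation between the momentum buffer and the incoming tensor, and the buffer is then moved toward the tensor with beta2. A standalone sketch of the same rule on a single tensor; the `lion_update` name and the SGD-style usage below are illustrative, not part of the package API.

```python
import torch

def lion_update(g, exp_avg, beta1=0.9, beta2=0.99):
    # direction: sign of the beta1-interpolation (leaves exp_avg untouched)
    update = exp_avg.lerp(g, 1 - beta1).sign_()
    # momentum buffer tracks the gradient with beta2
    exp_avg.lerp_(g, 1 - beta2)
    return update

# illustrative usage
p = torch.randn(10, requires_grad=True)
exp_avg = torch.zeros_like(p)
(p ** 2).sum().backward()
with torch.no_grad():
    p -= 1e-3 * lion_update(p.grad, exp_avg)
```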