torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/momentum/__init__.py
@@ -1,4 +1,14 @@
-"""
-Modules that implement momentum.
-"""
-from .momentum import HeavyBall, NesterovMomentum, RandomCoordinateMomentum, GradientAveraging
+from .averaging import Averaging, MedianAveraging, WeightedAveraging
+from .cautious import (
+    Cautious,
+    IntermoduleCautious,
+    ScaleByGradCosineSimilarity,
+    ScaleModulesByCosineSimilarity,
+    UpdateGradientSignConsistency,
+)
+from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
+from .experimental import CoordinateMomentum
+# from .matrix_momentum import MatrixMomentum
+
+from .momentum import NAG, HeavyBall
+from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
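The reorganized subpackage above re-exports the averaging, cautious, EMA, and experimental momentum transforms from torchzero.modules.momentum. A minimal import sketch, assuming these re-exports resolve as the file list suggests:

    # sketch only: names taken from the __init__.py hunk above
    from torchzero.modules.momentum import (
        Averaging, WeightedAveraging, MedianAveraging,
        Cautious, EMA, HeavyBall, NAG, CoordinateMomentum,
    )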
torchzero/modules/momentum/averaging.py
@@ -0,0 +1,78 @@
+from collections import deque
+from collections.abc import Sequence
+from typing import Any, Literal, cast
+
+import torch
+
+from ...core import TensorwiseTransform, Target
+from ...utils import tolist
+
+
+class Averaging(TensorwiseTransform):
+    def __init__(self, history_size: int, target: Target = 'update'):
+        defaults = dict(history_size=history_size)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        history_size = self.settings[param]['history_size']
+        state = self.state[param]
+        if 'history' not in state:
+            state['history'] = deque(maxlen=history_size)
+            state['average'] = torch.zeros_like(tensor)
+
+        history = state['history']; average = state['average']
+        if len(history) == history_size: average -= history[0]
+        history.append(tensor)
+        average += tensor
+
+        return average / len(history)
+
+class WeightedAveraging(TensorwiseTransform):
+    """weights are oldest to newest"""
+    def __init__(self, weights: Sequence[float] | torch.Tensor | Any, target: Target = 'update'):
+        defaults = dict(weights = tolist(weights))
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        weights = self.settings[param]['weights']
+        state = self.state[param]
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=len(weights))
+
+        history = state['history']
+        history.append(tensor)
+        if len(history) != len(weights):
+            weights = weights[-len(history):]
+
+        average = None
+        for i, (h, w) in enumerate(zip(history, weights)):
+            if average is None: average = h * (w / len(history))
+            else:
+                if w == 0: continue
+                average += h * (w / len(history))
+
+        assert average is not None
+        return average
+
+
+class MedianAveraging(TensorwiseTransform):
+    def __init__(self, history_size: int, target: Target = 'update'):
+        defaults = dict(history_size = history_size)
+        super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+    @torch.no_grad
+    def transform(self, tensor, param, grad, vars):
+        history_size = self.settings[param]['history_size']
+        state = self.state[param]
+
+        if 'history' not in state:
+            state['history'] = deque(maxlen=history_size)
+
+        history = state['history']
+        history.append(tensor)
+
+        stacked = torch.stack(tuple(history), 0)
+        return torch.quantile(stacked, 0.5, dim = 0)
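Averaging above keeps a sliding-window mean without re-summing the window each step: once the deque is full, the oldest tensor's contribution is subtracted from the running sum before the new tensor is appended. A minimal standalone sketch of that bookkeeping (plain torch, outside the TensorwiseTransform machinery):

    import torch
    from collections import deque

    history_size = 3
    history = deque(maxlen=history_size)
    average = torch.zeros(4)

    for step in range(5):
        t = torch.full((4,), float(step))
        if len(history) == history_size:
            average -= history[0]      # oldest entry is about to be evicted
        history.append(t)
        average += t
        print(average / len(history))  # sliding means: 0.0, 0.5, 1.0, 2.0, 3.0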
torchzero/modules/momentum/cautious.py
@@ -0,0 +1,181 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform, Module, Chainable
+from ...utils import NumberList, TensorList
+
+
+def cautious_(
+    tensors_: TensorList,
+    grads: TensorList,
+    normalize: bool,
+    eps: float,
+    mode: Literal['zero', 'grad', 'backtrack']
+):
+    # mask will be > 0 for parameters where both signs are the same
+    mask = (tensors_ * grads) > 0
+    if mode in ('zero', 'grad'):
+        if normalize and mode == 'zero':
+            fmask = mask.to(tensors_[0].dtype)
+            fmask /= fmask.global_mean().clip(min=eps) # type:ignore
+        else:
+            fmask = mask
+
+        tensors_ *= fmask
+
+        if mode == 'grad':
+            tensors_ += grads * mask.logical_not_()
+
+        return tensors_
+
+    # mode = 'backtrack'
+    tensors_ -= tensors_.mul(2).mul_(mask.logical_not_())
+    return tensors_
+
+class Cautious(Transform):
+    """Negates update for parameters where update and gradient sign is inconsistent.
+    Optionally normalizes the update by the number of parameters that are not masked.
+    This is meant to be used after any momentum-based modules.
+
+    Args:
+        normalize (bool, optional):
+            renormalize update after masking.
+            only has effect when mode is 'zero'. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+        mode (str, optional):
+            what to do with updates with inconsistent signs.
+
+            "zero" - set them to zero (as in paper)
+
+            "grad" - set them to the gradient
+
+            "backtrack" - negate them (same as using update magnitude and gradient sign)
+
+    reference
+        *Cautious Optimizers: Improving Training with One Line of Code.
+        Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+    """
+
+    def __init__(
+        self,
+        normalize=False,
+        eps=1e-6,
+        mode: Literal["zero", "grad", "backtrack"] = "zero",
+        target: Target = "update",
+    ):
+        defaults = dict(normalize=normalize, eps=eps, mode=mode)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[params[0]])
+        return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
+
+class UpdateGradientSignConsistency(Transform):
+    """1 where signs match 0 otherwise"""
+    def __init__(self, normalize = False, eps=1e-6, target: Target = 'update'):
+        defaults = dict(normalize=normalize, eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        normalize, eps = itemgetter('normalize', 'eps')(self.settings[params[0]])
+
+        mask = (TensorList(tensors).mul_(grads)).gt_(0)
+        if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
+
+        return mask
+
+class IntermoduleCautious(Module):
+    def __init__(
+        self,
+        main: Chainable,
+        compare: Chainable,
+        normalize=False,
+        eps=1e-6,
+        mode: Literal["zero", "grad", "backtrack"] = "zero",
+    ):
+        defaults = dict(normalize=normalize, eps=eps, mode=mode)
+        super().__init__(defaults)
+
+        self.set_child('main', main)
+        self.set_child('compare', compare)
+
+    @torch.no_grad
+    def step(self, vars):
+        main = self.children['main']
+        compare = self.children['compare']
+
+        main_vars = main.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(main_vars)
+
+        compare_vars = compare.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(compare_vars)
+
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[vars.params[0]])
+        vars.update = cautious_(
+            TensorList(main_vars.get_update()),
+            TensorList(compare_vars.get_update()),
+            normalize=normalize,
+            mode=mode,
+            eps=eps,
+        )
+
+        return vars
+
+class ScaleByGradCosineSimilarity(Transform):
+    def __init__(
+        self,
+        eps=1e-6,
+        target: Target = "update",
+    ):
+        defaults = dict(eps=eps)
+        super().__init__(defaults, uses_grad=True, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        assert grads is not None
+        eps = self.settings[params[0]]['eps']
+        tensors = TensorList(tensors)
+        grads = TensorList(grads)
+        cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
+
+        return tensors.mul_(cos_sim)
+
+class ScaleModulesByCosineSimilarity(Module):
+    def __init__(
+        self,
+        main: Chainable,
+        compare: Chainable,
+        eps=1e-6,
+    ):
+        defaults = dict(eps=eps)
+        super().__init__(defaults)
+
+        self.set_child('main', main)
+        self.set_child('compare', compare)
+
+    @torch.no_grad
+    def step(self, vars):
+        main = self.children['main']
+        compare = self.children['compare']
+
+        main_vars = main.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(main_vars)
+
+        compare_vars = compare.step(vars.clone(clone_update=True))
+        vars.update_attrs_from_clone_(compare_vars)
+
+        m = TensorList(main_vars.get_update())
+        c = TensorList(compare_vars.get_update())
+        eps = self.settings[vars.params[0]]['eps']
+
+        cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
+
+        vars.update = m.mul_(cos_sim)
+        return vars
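The core of the cautious update is a sign-agreement mask between the (momentum) update and the raw gradient; in 'zero' mode disagreeing coordinates are dropped, and the optional normalization rescales by the fraction of surviving coordinates. A plain-tensor sketch of that mode (the module itself operates on TensorList objects via the settings machinery shown above):

    import torch

    update = torch.tensor([ 0.5, -0.2, 0.1, -0.4])   # e.g. output of a momentum module
    grad   = torch.tensor([ 1.0,  0.3, 0.2, -0.1])

    mask = (update * grad) > 0               # True where update and gradient agree in sign
    fmask = mask.to(update.dtype)
    fmask /= fmask.mean().clip(min=1e-6)     # optional renormalization (normalize=True)
    print(update * fmask)                    # tensor([ 0.6667, 0.0000, 0.1333, -0.5333])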
torchzero/modules/momentum/ema.py
@@ -0,0 +1,173 @@
+from collections import deque
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import TensorList, NumberList
+from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
+
+
+class EMA(Transform):
+    """Maintains EMA of update.
+
+    Args:
+        momentum (float, optional): momentum (beta). Defaults to 0.9.
+        dampening (float, optional): momentum dampening. Defaults to 0.
+        debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+        lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        target (Target, optional): target to apply EMA to. Defaults to 'update'.
+    """
+    def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+        defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(self.settings[params[0]])
+
+        exp_avg = self.get_state('exp_avg', params=params, init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+
+        exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
+
+        if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+        else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+
+
+class EMASquared(Transform):
+    EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)
+
+    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
+
+class SqrtEMASquared(Transform):
+    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+
+    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update',):
+        defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return self.SQRT_EMA_SQ_FN(
+            TensorList(tensors),
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            debiased=debiased,
+            step=step,
+            pow=pow,
+        )
+
+
+class Debias(Transform):
+    def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
+        defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        settings = self.settings[params[0]]
+        pow = settings['pow']
+        alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+
+        return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
+
+class Debias2(Transform):
+    def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
+        defaults = dict(beta=beta, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        pow = self.settings[params[0]]['pow']
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+        return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
+
+class CenteredEMASquared(Transform):
+    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return centered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            pow=pow,
+        ).clone()
+
+class CenteredSqrtEMASquared(Transform):
+    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
+        beta = self.get_settings('beta', params=params, cls=NumberList)
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return sqrt_centered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_=exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta=beta,
+            debiased=debiased,
+            step=step,
+            max_exp_avg_sq_=max_exp_avg_sq,
+            pow=pow,
+        )
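Several of these transforms expose a debiased flag; the debias helper itself comes from ..functional, whose hunk is not shown here, but the Adam-style correction the docstring refers to divides the EMA by 1 - beta**step so that early steps are not biased toward the zero initialization. A small worked sketch of that standard correction:

    import torch

    beta, step = 0.9, 3
    ema = torch.tensor([0.271])      # EMA of a constant signal 1.0 after 3 steps, zero-initialized
    print(ema / (1 - beta**step))    # 0.271 / (1 - 0.9**3) = 1.0, the bias-corrected estimate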
torchzero/modules/momentum/experimental.py
@@ -0,0 +1,189 @@
+from collections.abc import Sequence
+from functools import partial
+from operator import itemgetter
+from typing import Literal
+
+import torch
+
+from ...core import Target, Transform
+from ...utils import NumberList, TensorList
+from ..functional import ema_, ema_sq_, sqrt_ema_sq_
+from .ema import EMASquared, SqrtEMASquared
+from .momentum import nag_
+
+
+def precentered_ema_sq_(
+    tensors: TensorList,
+    exp_avg_: TensorList,
+    exp_avg_sq_: TensorList,
+    beta1: float | NumberList,
+    beta2: float | NumberList,
+    step: int,
+    min_step: int,
+    pow: float,
+    max_exp_avg_sq_: TensorList | None,
+):
+    """
+    Squared EMA of (update - 1st EMA). Starts taking effect after `min_step` to avoid division by epsilon.
+
+    returns `exp_avg_sq_` or `max_exp_avg_sq_`.
+    """
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0, lerp=False)
+
+    if step < min_step: centered_update = tensors
+    else: centered_update = tensors - exp_avg_
+
+    exp_avg_sq_=ema_sq_(
+        centered_update,
+        exp_avg_sq_=exp_avg_sq_,
+        beta=beta2,
+        pow=pow,
+        max_exp_avg_sq_=max_exp_avg_sq_,
+    )
+    return exp_avg_sq_
+
+class PrecenteredEMASquared(Transform):
+    def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
+        defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
+        super().__init__(defaults, uses_grad=False, target=target)
+        self.current_step = 0
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        self.current_step += 1
+
+        beta1, beta2 = self.get_settings('beta1','beta2', params=params, cls=NumberList)
+        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(self.settings[params[0]])
+
+        if amsgrad:
+            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+        else:
+            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            max_exp_avg_sq = None
+
+        return precentered_ema_sq_(
+            TensorList(tensors),
+            exp_avg_ = exp_avg,
+            exp_avg_sq_=exp_avg_sq,
+            beta1=beta1,
+            beta2=beta2,
+            step = self.current_step,
+            min_step=min_step,
+            pow=pow,
+            max_exp_avg_sq_=max_exp_avg_sq,
+        ).clone()
+
+
+def nag_ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    pow: float,
+    lerp:bool=True,
+):
+    """
+    Nesterov EMA of squared tensors.
+
+    Returns `exp_avg_sq_` or `max_exp_avg_sq_`.
+    """
+    if pow == 1: tensors = tensors.abs()
+    elif pow%2 == 0: tensors = tensors.pow(pow)
+    else: tensors = tensors.pow(pow).abs()
+
+    exp_avg_sq_=nag_(tensors,velocity_=exp_avg_sq_,momentum=beta,dampening=0,lerp=lerp,)
+
+    # AMSGrad
+    if max_exp_avg_sq_ is not None:
+        max_exp_avg_sq_.maximum_(exp_avg_sq_)
+        exp_avg_sq_ = max_exp_avg_sq_
+
+    return exp_avg_sq_
+
+def sqrt_nag_ema_sq_(
+    tensors: TensorList,
+    exp_avg_sq_: TensorList,
+    beta: float | NumberList,
+    max_exp_avg_sq_: TensorList | None,
+    debiased: bool,
+    step: int,
+    pow: float,
+    lerp:bool=False,
+):
+    """
+    Square root of nesterov EMA of squared tensors.
+
+    Returns new tensors.
+    """
+    return sqrt_ema_sq_(tensors=tensors,exp_avg_sq_=exp_avg_sq_,beta=beta,max_exp_avg_sq_=max_exp_avg_sq_,
+                        pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
+
+class NesterovEMASquared(EMASquared):
+    EMA_SQ_FN = staticmethod(nag_ema_sq_)
+
+class SqrtNesterovEMASquared(SqrtEMASquared):
+    SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
+
+
+def coordinate_momentum_(
+    tensors: TensorList,
+    velocity_: TensorList,
+    p: float | NumberList,
+):
+    """
+    sets `velocity_` to p% random values from `tensors`.
+
+    Returns `velocity_`
+    """
+    mask = tensors.bernoulli_like(p).as_bool()
+    velocity_.masked_set_(mask, tensors)
+    return velocity_
+
+
+class CoordinateMomentum(Transform):
+    def __init__(self, p: float = 0.1, target: Target = 'update'):
+        defaults = dict(p=p)
+        super().__init__(defaults, uses_grad=False, target=target)
+
+    @torch.no_grad
+    def transform(self, tensors, params, grads, vars):
+        p = self.get_settings('p', params=params, cls=NumberList)
+        velocity = self.get_state('velocity', params=params, cls=TensorList)
+        return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
+
+
+# def multiplicative_momentum_(
+#     tensors_: TensorList,
+#     velocity_: TensorList,
+#     momentum: float | NumberList,
+#     dampening: float | NumberList,
+#     normalize_velocity: bool = True,
+#     abs: bool = False,
+#     lerp: bool = False,
+# ):
+#     """
+#     abs: if True, tracks momentum of absolute magnitudes.
+
+#     returns `tensors_`.
+#     """
+#     tensors_into_velocity = tensors_.abs() if abs else tensors_
+#     ema_(tensors_into_velocity, exp_avg_=velocity_, beta=momentum, dampening=0, lerp=lerp)
+
+#     if normalize_velocity: velocity_ = velocity_ / velocity_.std().add_(1e-8)
+#     return tensors_.mul_(velocity_.lazy_mul(1-dampening) if abs else velocity_.abs().lazy_mul_(1-dampening))
+
+
+# class MultiplicativeMomentum(Transform):
+#     """sucks"""
+#     def __init__(self, momentum: float = 0.9, dampening: float = 0,normalize_velocity: bool = True, abs: bool = False, lerp: bool = False):
+#         defaults = dict(momentum=momentum, dampening=dampening, normalize_velocity=normalize_velocity,abs=abs, lerp=lerp)
+#         super().__init__(defaults, uses_grad=False)
+
+#     @torch.no_grad
+#     def transform(self, tensors, params, grads, vars):
+#         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+#         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
+#         velocity = self.get_state('velocity', params=params, cls=TensorList)
+#         return multiplicative_momentum_(TensorList(target), velocity_=velocity, momentum=momentum, dampening=dampening,
+#                                         normalize_velocity=normalize_velocity,abs=abs,lerp=lerp)
+
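CoordinateMomentum keeps the previous velocity at most coordinates and overwrites a random fraction p of them with the fresh update, using the TensorList helpers bernoulli_like and masked_set_. A rough plain-torch equivalent of one such step (a sketch only; the real module operates on lists of per-parameter tensors through its state machinery):

    import torch

    p = 0.1
    velocity = torch.zeros(8)                                  # persistent velocity state
    update = torch.randn(8)                                    # incoming update

    mask = torch.bernoulli(torch.full_like(update, p)).bool()  # ~10% of coordinates selected
    velocity = torch.where(mask, update, velocity)             # only those coordinates change
    new_update = velocity.clone()                              # returned as the transformed update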