torchzero 0.1.7__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/muon.py
@@ -0,0 +1,222 @@
+ from operator import itemgetter
+ import math
+ import warnings
+ from collections.abc import Iterable, Sequence
+ from typing import Literal
+
+ import torch
+
+ from ...core import Modular, TensorwiseTransform, Target, Transform
+ from ...utils import enable_compilation
+
+
+ def reverse_dims(t:torch.Tensor):
+     return t.permute(*reversed(range(t.ndim)))
+
+ def _is_at_least_2d(p: torch.Tensor):
+     if (p.ndim >= 2) and (p.size(0) > 1) and (p.size(1) > 1): return True
+     return False
+
+ # stolen from:
+ # https://github.com/KellerJordan/Muon/blob/master/muon.py
+ @enable_compilation
+ def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
+     """
+     Applies to last 2 dims - so usually reverse_dims should be applied to G before and after.
+
+     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+     zero even beyond the point where the iteration no longer converges all the way to one everywhere
+     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+     performance at all relative to UV^T, where USV^T = G is the SVD.
+     """
+     assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+     a, b, c = (3.4445, -4.7750, 2.0315)
+     X = G.bfloat16()
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+
+     # Ensure spectral norm is at most 1
+     X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+     # Perform the NS iterations
+     for _ in range(steps):
+         A = X @ X.mT
+         B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+         X = a * X + B @ X
+
+     if G.size(-2) > G.size(-1):
+         X = X.mT
+     return X
+
+ # stolen from https://github.com/MarkTuddenham/Orthogonal-Optimisers.
+ # Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
+ # Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
+ @torch.no_grad
+ def _svd_orthogonalize(G: torch.Tensor, warn_fail=True) -> torch.Tensor:
+     """
+     Applies to first 2 dims and isn't batched - rest of dimensions are flattened.
+     """
+     X = G.view(G.shape[0], -1)
+
+     t = False
+     if X.size(0) > X.size(1):
+         X = X.T
+         t = True
+
+     orth_X: torch.Tensor | None = None
+     try:
+         u, s, vt = torch.linalg.svd(X, full_matrices=False) # pylint:disable=not-callable
+         orth_X = u @ vt
+     except RuntimeError:
+         # if warn: logging.warning('Failed to perform SVD, adding some noise.')
+         try:
+             u, s, v = torch.svd_lowrank(
+                 X,
+                 q=1, # assume rank is at least 1
+                 M=1e-4 * X.mean() * torch.randn_like(X))
+             orth_X = u @ v.T
+         except RuntimeError:
+             if warn_fail: warnings.warn(('Failed to perform SVD with noise,'
+                                          ' skipping gradient orthogonalisation'))
+     if orth_X is not None:
+         if t: orth_X = orth_X.T
+         return orth_X.view_as(G)
+
+     return G # fail
+
+
+ @torch.no_grad
+ def _dual_norm_correction(X: torch.Tensor, g: torch.Tensor, batch_first):
+     """batch first means it applies to last 2 dims, otherwise to 1st two dims"""
+     # this is from https://github.com/leloykun/adaptive-muon
+     # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
+     if batch_first: X = torch.einsum('...ij,...ij,...ab->...ab', g.type_as(X), X, X)
+     else: X = torch.einsum('ij...,ij...,ab...->ab...', g.type_as(X), X, X)
+     return X
+
+
+ # code from
+ # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
+ def adjust_lr_for_muon(lr, param_shape):
+     A, B = param_shape[:2]
+     # We adjust the learning rate and weight decay based on the size of the parameter matrix
+     # as described in the paper
+     adjusted_ratio = 0.2 * math.sqrt(max(A, B))
+     adjusted_lr = lr * adjusted_ratio
+     return adjusted_lr
+
+ def _orthogonalize_tensor(
+     tensor: torch.Tensor,
+     steps: int = 5,
+     method: Literal["newton-schulz", "svd"] = "newton-schulz",
+ ):
+     if method == 'newton-schulz': return reverse_dims(zeropower_via_newtonschulz5(reverse_dims(tensor), steps)).type_as(tensor)
+     if method == 'svd': return _svd_orthogonalize(tensor, False)
+     raise ValueError(method)
+
+
+ def orthogonalize_grads_(
+     params: Iterable[torch.Tensor],
+     steps: int = 5,
+     dual_norm_correction=False,
+     method: Literal["newton-schulz", "svd"] = "newton-schulz",
+ ):
+     """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
+
+     This sets gradients in-place. Applies along first 2 dims (expected to be `out_channels, in_channels`).
+
+     Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
+     Args:
+         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
+         steps (int, optional):
+             The number of Newton-Schulz iterations to run. Defaults to 5.
+         dual_norm_correction (bool, optional):
+             enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
+         method (str, optional):
+             Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
+     """
+     for p in params:
+         if (p.grad is not None) and _is_at_least_2d(p.grad):
+             X = _orthogonalize_tensor(p.grad, steps, method)
+             if dual_norm_correction: X = _dual_norm_correction(X, p.grad, batch_first=False)
+             p.grad.set_(X.view_as(p)) # pyright:ignore[reportArgumentType]
+
+
+
+ class Orthogonalize(TensorwiseTransform):
+     """Uses Newton-Schulz iteration or SVD to compute the zeroth power / orthogonalization of update along first 2 dims.
+
+     To disable orthogonalization for a parameter, put it into a parameter group with "orthogonalize" = False.
+     The Muon page says that embeddings and classifier heads should not be orthogonalized.
+     Usually only matrix parameters that are directly used in matmuls should be orthogonalized.
+
+     To make Muon, use Split with Adam on 1d params: TODO code example.
+
+     Args:
+         ns_steps (int, optional):
+             The number of Newton-Schulz iterations to run. Defaults to 5.
+         adjust_lr (bool, optional):
+             Enables LR adjustment based on parameter size from "Muon is Scalable for LLM Training". Defaults to False.
+         dual_norm_correction (bool, optional):
+             enables dual norm correction from https://github.com/leloykun/adaptive-muon. Defaults to False.
+         method (str, optional):
+             Newton-Schulz is very fast, SVD is extremely slow but can be slightly more precise.
+         target (str, optional):
+             what to set on vars.
+     """
+     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
+                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
+         defaults = dict(orthogonalize=True, ns_steps=ns_steps, dual_norm_correction=dual_norm_correction, adjust_lr=adjust_lr, method=method.lower())
+         super().__init__(uses_grad=False, defaults=defaults, target=target)
+
+     @torch.no_grad
+     def transform(self, tensor, param, grad, vars):
+         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
+             'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(self.settings[param])
+
+         if not orthogonalize: return tensor
+
+         if _is_at_least_2d(tensor):
+
+             X = _orthogonalize_tensor(tensor, ns_steps, method)
+
+             if dual_norm_correction:
+                 X = _dual_norm_correction(X, tensor, batch_first=False)
+
+             if adjust_lr:
+                 X.mul_(adjust_lr_for_muon(1, param.shape))
+
+             return X.view_as(param)
+
+         return tensor
+
+
+ class DualNormCorrection(TensorwiseTransform):
+     """Dual norm correction for dualizer based optimizers (https://github.com/leloykun/adaptive-muon).
+     Orthogonalize already has this built in with the `dual_norm_correction` setting."""
+     def __init__(self, target: Target='update'):
+         super().__init__({}, uses_grad=True, target=target)
+
+     def transform(self, tensor, param, grad, vars):
+         assert grad is not None
+         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
+             return _dual_norm_correction(tensor, grad, batch_first=False)
+         return tensor
+
+
+ class MuonAdjustLR(Transform):
+     """LR adjustment for Muon from "Muon is Scalable for LLM Training" (https://github.com/MoonshotAI/Moonlight/tree/master).
+     Orthogonalize already has this built in with the `adjust_lr` setting, however you might want to move this to be later in the chain."""
+     def __init__(self, alpha: float = 1, target: Target='update'):
+         defaults = dict(alpha=alpha)
+         super().__init__(defaults=defaults, uses_grad=False, target=target)
+
+     def transform(self, tensors, params, grads, vars):
+         alphas = self.get_settings('alpha', params=params)
+         tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
+         tensors = [i[0] for i in tensors_alphas]
+         a = [i[1] for i in tensors_alphas]
+         torch._foreach_mul_(tensors, a)
+         return tensors
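
The new muon.py exposes both a functional entry point (orthogonalize_grads_) and chainable modules (Orthogonalize, DualNormCorrection, MuonAdjustLR). The following is a minimal usage sketch based only on the signatures shown in this hunk; the direct module import path is an assumption drawn from the file layout, and any public re-export may differ.

    import torch
    from torch import nn

    # Assumed import path, mirroring the file layout in this diff; a public re-export may differ.
    from torchzero.modules.optimizers.muon import orthogonalize_grads_, adjust_lr_for_muon

    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    loss = model(torch.randn(32, 64)).square().mean()
    loss.backward()

    # Orthogonalize matrix-shaped gradients in place; 1-d tensors (biases) are skipped
    # because _is_at_least_2d returns False for them.
    orthogonalize_grads_(model.parameters(), steps=5, method="newton-schulz")

    # The Newton-Schulz iteration pushes singular values toward 1 (roughly into [0.5, 1.5]
    # per the docstring above), so the result is only approximately semi-orthogonal.
    W = model[0].weight.grad                    # shape (128, 64)
    print(torch.linalg.svdvals(W))

    # The Moonlight LR adjustment is just lr * 0.2 * sqrt(max(out_dim, in_dim)):
    print(adjust_lr_for_muon(1e-3, (128, 64)))  # 1e-3 * 0.2 * sqrt(128) ≈ 2.26e-3

Building a full Muon-style optimizer from the Orthogonalize module (chained with momentum and restricted to matrix parameters, as the class docstring notes) depends on torchzero's Modular/Split wiring and is not shown here.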
torchzero/modules/optimizers/orthograd.py
@@ -0,0 +1,55 @@
+ from operator import itemgetter
+ import math
+ import warnings
+ from collections.abc import Iterable, Sequence
+ from typing import Literal
+
+ import torch
+
+ from ...core import Target, Transform
+ from ...utils import as_tensorlist
+
+ def orthograd_(params: Iterable[torch.Tensor], eps: float = 1e-30):
+     """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
+
+     Args:
+         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to apply ⟂Grad to.
+         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
+
+     reference
+         https://arxiv.org/abs/2501.04697
+     """
+     params = as_tensorlist(params).with_grad()
+     grad = params.grad
+     grad -= (params.dot(grad)/(params.dot(params) + eps)) * params
+
+
+ class OrthoGrad(Transform):
+     """Applies ⟂Grad - projects gradient of an iterable of parameters to be orthogonal to the weights.
+
+     Args:
+         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-8)
+         renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
+         target (Target, optional): what to set on vars. Defaults to 'update'.
+     """
+     def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
+         defaults = dict(eps=eps, renormalize=renormalize)
+         super().__init__(defaults, uses_grad=False, target=target)
+
+     def transform(self, tensors, params, grads, vars):
+         settings = self.settings[params[0]]
+         eps = settings['eps']
+         renormalize = settings['renormalize']
+
+         params = as_tensorlist(params)
+         target = as_tensorlist(tensors)
+
+         scale = params.dot(target)/(params.dot(params) + eps)
+         if renormalize:
+             norm = target.global_vector_norm()
+             target -= params * scale
+             target *= (norm / target.global_vector_norm())
+             return target
+
+         target -= params * scale
+         return target
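
For intuition, the projection that orthograd_ and OrthoGrad.transform perform above is simply removal of the component of the update that points along the weights, optionally grafted back onto the original norm. The sketch below reproduces that arithmetic on a single flat tensor in plain PyTorch, without torchzero's TensorList helpers; the variable names are illustrative only.

    import torch

    torch.manual_seed(0)
    w = torch.randn(1000)   # flattened weights
    g = torch.randn(1000)   # flattened gradient / update
    eps = 1e-8

    # project g onto the orthogonal complement of w
    scale = w.dot(g) / (w.dot(w) + eps)
    g_orth = g - scale * w

    # renormalize=True: graft the projected update back onto the original norm
    g_orth = g_orth * (g.norm() / g_orth.norm())

    print(torch.cosine_similarity(w, g_orth, dim=0))  # ~0: update is orthogonal to the weights
    print(g.norm(), g_orth.norm())                    # norms match after grafting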
torchzero/modules/optimizers/rmsprop.py
@@ -1,51 +1,103 @@
- from collections import abc
-
- import torch
-
- from ...tensorlist import TensorList
- from ...core import OptimizerModule
-
-
- def _rmsprop_step_(ascent: TensorList, mean_sqr: TensorList, smoothing, eps: TensorList):
-     mean_sqr.mul_(smoothing).addcmul_(ascent, ascent, value = 1 - smoothing)
-     return ascent.div_(mean_sqr.sqrt().add_(eps))
-
- def _centered_rmsprop_step_(ascent: TensorList, mean_sqr: TensorList, mean: TensorList, smoothing, eps: TensorList):
-     mean_sqr.mul_(smoothing).addcmul_(ascent, ascent, value = 1 - smoothing)
-     mean.lerp_compat_(ascent, 1-smoothing)
-     return ascent.div_(mean_sqr.addcmul(mean, mean, value=-1).sqrt_().add_(eps))
-
- class RMSProp(OptimizerModule):
-     """
-     Divides ascent direction by running average of its mean square root.
-
-     Exactly matches `torch.optim.RMSProp`.
-
-     Args:
-         smoothing (float, optional):
-             smoothing constant (decay of ascent mean square root running average).
-             Defaults to 0.99.
-         eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-8.
-         centered (float, optional):
-             if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance.
-             Defaults to False.
-
-     reference
-         https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
-     """
-     def __init__(self, smoothing: float = 0.99, eps: float = 1e-8, centered=False):
-
-         defaults = dict(smoothing = smoothing, eps = eps)
-         super().__init__(defaults)
-         self.centered = centered
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         settings = self.get_all_group_keys()
-         if self.centered:
-             mean, mean_sqr = self.get_state_keys('mean', 'mean_sqr')
-             updated_direction = _centered_rmsprop_step_(ascent, mean_sqr, mean, settings['smoothing'], settings['eps'])
-         else:
-             mean_sqr = self.get_state_key('mean_sqr')
-             updated_direction = _rmsprop_step_(ascent, mean_sqr, settings['smoothing'], settings['eps'])
-         return updated_direction
+ from operator import itemgetter
+ from typing import Literal
+
+ import torch
+
+ from ...core import Module, Target, Transform, Chainable, Vars, apply
+ from ...utils import NumberList, TensorList
+ from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
+
+
+ def rmsprop_(
+     tensors_: TensorList,
+     exp_avg_sq_: TensorList,
+     smoothing: float | NumberList,
+     eps: float | NumberList,
+     debiased: bool,
+     step: int,
+     exp_avg_: TensorList | None = None,
+     max_exp_avg_sq_: TensorList | None = None,
+     pow: float = 2,
+
+     # inner args
+     inner: Module | None = None,
+     params: list[torch.Tensor] | None = None,
+     grads: list[torch.Tensor] | None = None,
+     vars: Vars | None = None,
+ ):
+     """returns `tensors_`"""
+     if exp_avg_ is not None:
+         sqrt_exp_avg_sq = sqrt_centered_ema_sq_(tensors=tensors_, exp_avg_=exp_avg_,
+                                                 exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
+                                                 beta=smoothing,debiased=debiased,step=step,pow=pow)
+     else:
+         sqrt_exp_avg_sq = sqrt_ema_sq_(tensors=tensors_,exp_avg_sq_=exp_avg_sq_,max_exp_avg_sq_=max_exp_avg_sq_,
+                                        beta=smoothing,debiased=debiased,step=step,pow=pow)
+
+     if inner is not None:
+         assert params is not None
+         tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+
+     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
+
+ class RMSprop(Transform):
+     """Divides gradient by EMA of gradient squares. Matches pytorch RMSprop if "init" is set to "zeros".
+
+     Args:
+         smoothing (float, optional): beta for exponential moving average of gradient squares. Defaults to 0.99.
+         eps (float, optional): epsilon for division. Defaults to 1e-8.
+         centered (bool, optional): whether to center EMA of gradient squares using an additional EMA. Defaults to False.
+         debiased (bool, optional): applies Adam debiasing. Defaults to False.
+         amsgrad (bool, optional): Whether to divide by maximum of EMA of gradient squares instead. Defaults to False.
+         pow (float, optional): power used in second momentum power and root. Defaults to 2.
+         init (str, optional): how to initialize EMA, either "update" to use first update or "zeros". Defaults to "update".
+         inner (Chainable | None, optional): Inner modules that are applied after updating EMA and before preconditioning. Defaults to None.
+     """
+     def __init__(
+         self,
+         smoothing: float = 0.99,
+         eps: float = 1e-8,
+         centered: bool = False,
+         debiased: bool = False,
+         amsgrad: bool = False,
+         pow: float = 2,
+         init: Literal["zeros", "update"] = "update",
+         inner: Chainable | None = None,
+     ):
+         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
+         super().__init__(defaults=defaults, uses_grad=False)
+         self.current_step = 0
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     def transform(self, tensors, params, grads, vars):
+         self.current_step += 1
+
+         smoothing,eps = self.get_settings('smoothing', 'eps', params=params, cls=NumberList)
+         centered,debiased,amsgrad,pow,init = itemgetter('centered','debiased','amsgrad','pow','init')(self.settings[params[0]])
+
+         exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+         exp_avg = self.get_state('exp_avg', params=params, cls=TensorList) if centered else None
+         max_exp_avg_sq = self.get_state('max_exp_avg_sq', params=params, cls=TensorList) if amsgrad else None
+
+         if init == 'update' and self.current_step == 1:
+             exp_avg_sq.set_([t**2 for t in tensors])
+             if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
+
+         return rmsprop_(
+             TensorList(tensors),
+             exp_avg_sq_=exp_avg_sq,
+             smoothing=smoothing,
+             eps=eps,
+             debiased=debiased,
+             step=self.current_step,
+             exp_avg_=exp_avg,
+             max_exp_avg_sq_=max_exp_avg_sq,
+             pow=pow,
+
+             # inner args
+             inner=self.children.get("inner", None),
+             params=params,
+             grads=grads,
+             vars=vars,
+         )
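
The refactored RMSprop delegates the actual math to sqrt_ema_sq_ / sqrt_centered_ema_sq_ in modules/functional.py, which are not shown in this hunk. As a rough single-tensor sketch of the non-centered path (debiased=False, amsgrad off, pow=2), the rule it implements is the classic RMSprop update, with init="update" seeding the EMA from the first squared update rather than from zeros:

    import torch

    torch.manual_seed(0)
    smoothing, eps = 0.99, 1e-8

    g = torch.randn(10)            # stand-in for the incoming update on one parameter
    exp_avg_sq = g ** 2            # init="update": EMA starts at the first squared update, not zeros

    for step in range(1, 4):
        g = torch.randn(10)
        exp_avg_sq = smoothing * exp_avg_sq + (1 - smoothing) * g ** 2
        update = g / (exp_avg_sq.sqrt() + eps)   # divide by sqrt of the EMA of squares
        print(step, update.norm())

The centered, debiased and amsgrad branches layer an EMA of the raw updates, Adam-style bias correction, and a running maximum of the EMA on top of this, as described in the class docstring above.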