torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/operations/reduction.py +0 -134
@@ -1,134 +0,0 @@
- from collections.abc import Callable, Iterable
- import numpy as np
- import torch
-
- from ...core import OptimizerModule
-
- _Value = int | float | OptimizerModule | Iterable[OptimizerModule]
-
-
- class Sum(OptimizerModule):
-     """calculates sum of multiple updates.
-
-     Args:
-         *modules:
-             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
-     def __init__(
-         self,
-         *modules: _Value,
-     ):
-         super().__init__({})
-
-         scalars = [i for i in modules if isinstance(i, (int,float))]
-         self.scalar = sum(scalars) if len(scalars) > 0 else None
-
-         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
-             self._set_child_(i, module)
-
-     @torch.no_grad
-     def step(self, vars):
-         if len(self.children) == 1:
-             vars.ascent = self.children[0].return_ascent(vars)
-             if self.scalar is not None: vars.ascent += self.scalar
-             return self._update_params_or_step_with_next(vars)
-
-         sum = None
-         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-             if i == len(self.children) - 1: cur_state = vars
-             else: cur_state = vars.copy(clone_ascent = True)
-
-             if sum is None: sum = c.return_ascent(cur_state)
-             else: sum += c.return_ascent(cur_state)
-
-             if i != len(self.children) - 1: vars.update_attrs_(cur_state)
-
-         assert sum is not None
-         if self.scalar is not None: sum += self.scalar
-         vars.ascent = sum
-         return self._update_params_or_step_with_next(vars)
-
- class Mean(OptimizerModule):
-     """calculates mean of multiple updates.
-
-     Args:
-         *modules:
-             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
-
-     def __init__(
-         self,
-         *modules: _Value,
-     ):
-         super().__init__({})
-
-         scalars = [i for i in modules if isinstance(i, (int,float))]
-         self.scalar = sum(scalars) if len(scalars) > 0 else None
-
-         self.n_values = len(modules)
-
-         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
-             self._set_child_(i, module)
-
-     @torch.no_grad
-     def step(self, vars):
-         if len(self.children) == 1:
-             vars.ascent = self.children[0].return_ascent(vars)
-             if self.scalar is not None: vars.ascent += self.scalar
-             if self.n_values > 1: vars.ascent /= self.n_values
-             return self._update_params_or_step_with_next(vars)
-
-         sum = None
-         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-             if i == len(self.children) - 1: cur_state = vars
-             else: cur_state = vars.copy(clone_ascent = True)
-
-             if sum is None: sum = c.return_ascent(cur_state)
-             else: sum += c.return_ascent(cur_state)
-
-             if i != len(self.children) - 1: vars.update_attrs_(cur_state)
-
-         assert sum is not None
-         if self.scalar is not None: sum += self.scalar
-         if self.n_values > 1: sum /= self.n_values
-         vars.ascent = sum
-         return self._update_params_or_step_with_next(vars)
-
- class Product(OptimizerModule):
-     """calculates product of multiple updates.
-
-     Args:
-         *modules:
-             either OptimizerModules or iterables of OptimizerModules to chain. Scalars are also allowed."""
-
-     def __init__(
-         self,
-         *modules: _Value,
-     ):
-         super().__init__({})
-
-         scalars = [i for i in modules if isinstance(i, (int,float))]
-         self.scalar = np.prod(scalars).item() if len(scalars) > 0 else None
-
-         for i,module in enumerate(i for i in modules if not isinstance(i, (int, float))):
-             self._set_child_(i, module)
-
-     @torch.no_grad
-     def step(self, vars):
-         if len(self.children) == 1:
-             vars.ascent = self.children[0].return_ascent(vars)
-             if self.scalar is not None: vars.ascent *= self.scalar
-             return self._update_params_or_step_with_next(vars)
-
-         prod = None
-         for i, c in sorted(self.children.items(), key=lambda x: x[0]):
-             if i == len(self.children) - 1: cur_state = vars
-             else: cur_state = vars.copy(clone_ascent = True)
-
-             if prod is None: prod = c.return_ascent(cur_state)
-             else: prod *= c.return_ascent(cur_state)
-
-             if i != len(self.children) - 1: vars.update_attrs_(cur_state)
-
-         assert prod is not None
-         if self.scalar is not None: prod *= self.scalar
-         vars.ascent = prod
-         return self._update_params_or_step_with_next(vars)
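The Sum, Mean and Product modules above combined the updates returned by several child modules into a single update. A minimal standalone sketch of the same idea in plain PyTorch (illustrative only, not the torchzero API; the helper name is made up):

import torch

def mean_update(candidates: list[torch.Tensor]) -> torch.Tensor:
    # average several candidate update directions for the same parameter,
    # which is what the removed Mean module did with its child modules
    return torch.stack(candidates).mean(dim=0)

grad = torch.randn(10)
combined = mean_update([grad.sign(), 0.9 * grad])  # e.g. a sign update and a scaled update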
torchzero/modules/operations/singular.py +0 -113
@@ -1,113 +0,0 @@
- from collections.abc import Iterable
- from operator import methodcaller
-
- import torch
-
- from ...core import OptimizerModule
- from ...tensorlist import TensorList
-
-
- class Operation(OptimizerModule):
-     """Applies an operation to the ascent, supported operations:
-
-     `abs`, `sign`, `sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `sinh`, `cosh`,
-     `tanh`, `log`, `log1p`, `log2`, `log10`, `erf`, `erfc`, `exp`, `neg`, `reciprocal`,
-     `copy`, `zero`, `sqrt`, `floor`, `ceil`, `round`."""
-     def __init__(self, operation: str):
-         super().__init__({})
-         self.operation = methodcaller(f'{operation}_')
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return self.operation(ascent)
-
- class Reciprocal(OptimizerModule):
-     """*1 / update*"""
-     def __init__(self,):
-         super().__init__({})
-
-     @torch.no_grad()
-     def _update(self, vars, ascent): return ascent.reciprocal_()
-
- class Negate(OptimizerModule):
-     """minus update"""
-     def __init__(self,):
-         super().__init__({})
-
-     @torch.no_grad()
-     def _update(self, vars, ascent): return ascent.neg_()
-
-
- def sign_grad_(params: Iterable[torch.Tensor]):
-     """Apply sign function to gradients of an iterable of parameters.
-
-     Args:
-         params (abc.Iterable[torch.Tensor]): an iterable of Tensors or a single Tensor.
-     """
-     TensorList(params).get_existing_grads().sign_()
-
- class Sign(OptimizerModule):
-     """applies sign function to the update"""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent.sign_()
-
- class Abs(OptimizerModule):
-     """takes absolute values of the update."""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent.abs_()
-
- class Sin(OptimizerModule):
-     """applies sin function to the ascent"""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent.sin_()
-
- class Cos(OptimizerModule):
-     """applies cos function to the ascent"""
-     def __init__(self):
-         super().__init__({})
-
-     @torch.no_grad
-     def _update(self, vars, ascent): return ascent.cos_()
-
-
- class NanToNum(OptimizerModule):
-     """Convert `nan`, `inf` and `-inf` to numbers.
-
-     Args:
-         nan (optional): the value to replace NaNs with. Default is zero.
-         posinf (optional): if a Number, the value to replace positive infinity values with.
-             If None, positive infinity values are replaced with the greatest finite value
-             representable by input's dtype. Default is None.
-         neginf (optional): if a Number, the value to replace negative infinity values with.
-             If None, negative infinity values are replaced with the lowest finite value
-             representable by input's dtype. Default is None.
-     """
-     def __init__(self, nan=None, posinf=None, neginf=None):
-         super().__init__({})
-         self.nan = nan
-         self.posinf = posinf
-         self.neginf = neginf
-
-     @torch.no_grad()
-     def _update(self, vars, ascent): return ascent.nan_to_num_(self.nan, self.posinf, self.neginf)
-
-
- class MagnitudePower(OptimizerModule):
-     """Raises update to the `value` power, but preserves the sign when the power is odd."""
-     def __init__(self, value: int | float):
-         super().__init__({})
-         self.value = value
-
-     @torch.no_grad()
-     def _update(self, vars, ascent):
-         if self.value % 2 == 1: return ascent.pow_(self.value)
-         return ascent.abs().pow_(self.value) * ascent.sign()
-
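The MagnitudePower module at the end of this file raises the update to a power while keeping each element's sign. A hedged sketch of that transform on a plain tensor (illustrative only, not the removed module itself):

import torch

def magnitude_power(update: torch.Tensor, value: float) -> torch.Tensor:
    # raise the magnitudes to `value` and restore the original signs
    return update.abs().pow(value) * update.sign()

u = torch.tensor([-2.0, 0.5, 3.0])
print(magnitude_power(u, 2.0))  # tensor([-4.0000, 0.2500, 9.0000])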
torchzero/modules/optimizers/sgd.py +0 -54
@@ -1,54 +0,0 @@
- import typing as T
-
- import torch
-
- from ...core import OptimizerModule
- from ..momentum.momentum import _heavyball_step, _nesterov_step_
-
- class SGD(OptimizerModule):
-     """Same as `torch.optim.SGD` but as an optimizer module. Exactly matches `torch.optim.SGD`, except
-     nesterov momentum additionally supports dampening, and negative momentum is allowed.
-
-     Args:
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         alpha (float, optional): learning rate. Defaults to 1.
-     """
-     def __init__(
-         self,
-         momentum: float = 0,
-         dampening: float = 0,
-         weight_decay: float = 0,
-         nesterov: bool = False,
-         alpha: float = 1,
-     ):
-
-         defaults = dict(alpha=alpha, momentum=momentum, dampening=dampening, weight_decay=weight_decay,)
-         super().__init__(defaults)
-         self.nesterov = nesterov
-         self.current_step = 0
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         settings = self.get_all_group_keys()
-
-         if any(i != 0 for i in settings['weight_decay']):
-             ascent += params * settings['weight_decay']
-
-         if any(i != 1 for i in settings['alpha']):
-             ascent *= settings['alpha']
-
-         if any(i != 0 for i in settings['momentum']):
-             velocity = self.get_state_key('velocity', init = torch.zeros_like if self.nesterov else ascent)
-             # consistency with pytorch which on first step only initializes momentum
-             if self.current_step > 0 or self.nesterov:
-                 # nesterov step can be done in-place, polyak returns new direction
-                 if self.nesterov: _nesterov_step_(ascent, velocity, settings['momentum'], settings['dampening'])
-                 else: ascent = _heavyball_step(ascent, velocity, settings['momentum'], settings['dampening'])
-
-         self.current_step += 1
-         return ascent
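For context, the heavy-ball and Nesterov steps this module delegated to follow the usual torch.optim.SGD recurrences. A hedged sketch in plain PyTorch (these helpers are illustrative, not the removed _heavyball_step / _nesterov_step_):

import torch

def heavyball_step(update: torch.Tensor, velocity: torch.Tensor, momentum: float, dampening: float) -> torch.Tensor:
    # v <- momentum * v + (1 - dampening) * update; the velocity itself is the new direction
    velocity.mul_(momentum).add_(update, alpha=1 - dampening)
    return velocity.clone()

def nesterov_step(update: torch.Tensor, velocity: torch.Tensor, momentum: float, dampening: float) -> torch.Tensor:
    # same velocity recurrence, but the direction looks ahead: update + momentum * v
    velocity.mul_(momentum).add_(update, alpha=1 - dampening)
    return update + momentum * velocity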
torchzero/modules/orthogonalization/__init__.py +0 -2
@@ -1,2 +0,0 @@
- from .svd import Orthogonalize, orthogonalize_grad_
- from .newtonschulz import ZeropowerViaNewtonSchulz, zeropower_via_newtonschulz_, DualNormCorrection
torchzero/modules/orthogonalization/newtonschulz.py +0 -159
@@ -1,159 +0,0 @@
- """
- Newton-Schulz iteration code is taken from https://github.com/KellerJordan/Muon
-
- Keller Jordan and Yuchen Jin and Vlado Boza and You Jiacheng and Franz Cecista and Laker Newhouse and Jeremy Bernstein.
- Muon: An optimizer for hidden layers in neural networks (2024). URL: https://kellerjordan.github.io/posts/muon
- """
- from collections.abc import Iterable
-
- import torch
-
- from ...core import OptimizerModule, _Targets
- # from ...utils.compile import maybe_compile
-
- def _zeropower_via_newtonschulz5(G, steps):
-     """
-     code from https://github.com/KellerJordan/Muon
-
-     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
-     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
-     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
-     zero even beyond the point where the iteration no longer converges all the way to one everywhere
-     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
-     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
-     performance at all relative to UV^T, where USV^T = G is the SVD.
-     """
-     assert len(G.shape) == 2
-     a, b, c = (3.4445, -4.7750, 2.0315)
-     X = G.bfloat16()
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     # Ensure spectral norm is at most 1
-     X = X / (X.norm() + 1e-7)
-     # Perform the NS iterations
-     for _ in range(steps):
-         A = X @ X.T
-         B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
-         X = a * X + B @ X
-
-     if G.size(0) > G.size(1):
-         X = X.T
-
-     return X
-
- _compiled_zeropower_via_newtonschulz5 = torch.compile(_zeropower_via_newtonschulz5)
-
-
- def zeropower_via_newtonschulz_(params: Iterable[torch.Tensor], steps: int = 6, adaptive = False, compiled = True):
-     """Uses newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
-
-     This sets gradients in-place.
-
-     Note that the Muon page says that embeddings and classifier heads should not be orthogonalized.
-
-     The orthogonalization code is taken from https://github.com/KellerJordan/Muon
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
-         steps (int): The number of Newton-Schulz iterations to run. (6 is probably always enough).
-             The number of Newton-Schulz iterations to run. (6 is probably always enough). Defaults to 6.
-         adaptive (bool, optional):
-             Enables adaptation to scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to False.
-         compiled (bool, optional):
-             Uses compiled newton-Schulz iteration function. Faster but won't work on windows. Defaults to True.
-
-
-     """
-     if compiled: fn = _compiled_zeropower_via_newtonschulz5
-     else: fn = _zeropower_via_newtonschulz5
-     for p in params:
-         if p.grad is not None and p.grad.ndim >= 2 and min(p.grad.shape) >= 2:
-             G = p.grad.view(p.grad.shape[0], -1)
-             X = fn(G, steps)
-
-             if adaptive:
-                 # this is from https://github.com/leloykun/adaptive-muon
-                 X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X) # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
-
-             p.grad = X.reshape_as(p.grad).to(p.grad, copy=False)
-
-
- class ZeropowerViaNewtonSchulz(OptimizerModule):
-     """Uses Newton-Schulz iteration to compute the zeroth power / orthogonalization of gradients of an iterable of parameters.
-
-     To disable orthogonalization for a parameter, put it into a parameter group with "newtonshultz" = False.
-     The Muon page says that embeddings and classifier heads should not be orthogonalized.
-
-     The orthogonalization code is taken from https://github.com/KellerJordan/Muon.
-
-     Note that unlike this module, Muon also uses Adam for gradients that are not orthogonalized,
-     so I'd still recommend using it. Maybe use `Wrap` to wrap it into a module (I will make muon
-     with selectable modules to optimize non-muon params soon)
-
-     However not using Adam, or putting Adam module after this to apply it to ALL updates, both seem
-     to work quite well too.
-
-     Args:
-         ns_steps (int, optional):
-             The number of Newton-Schulz iterations to run. (6 is probably always enough). Defaults to 6.
-         adaptive (bool, optional):
-             Enables adaptation to scale of gradients (from https://github.com/leloykun/adaptive-muon). Defaults to True.
-         compiled (bool, optional):
-             Uses compiled newton-Schulz iteration function. Faster but won't work on windows. Defaults to True.
-         target (str, optional):
-             determines what this module updates.
-
-             "ascent" - it updates the ascent
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-     """
-     def __init__(self, ns_steps = 6, adaptive = False, compiled=True, target:_Targets='ascent'):
-         defaults = dict(newtonshultz = True, ns_steps=ns_steps, adaptive=adaptive)
-         super().__init__(defaults, target=target)
-
-         if compiled: self._zeropower_via_newtonschulz5 = _compiled_zeropower_via_newtonschulz5
-         else: self._zeropower_via_newtonschulz5 = _zeropower_via_newtonschulz5
-
-     def _update(self, vars, ascent):
-         toggle, ns_steps, adaptive = self.get_group_keys('newtonshultz', 'ns_steps', 'adaptive', cls=list)
-
-         for asc, enable, steps, ada in zip(ascent, toggle, ns_steps, adaptive):
-             if enable and len([i for i in asc.shape if i > 1]) != 0:
-                 G = asc.view(asc.shape[0], -1)
-                 X = self._zeropower_via_newtonschulz5(G, steps)
-
-                 if ada:
-                     # this is from https://github.com/leloykun/adaptive-muon
-                     X = torch.einsum('ij,ij,ab->ab', G.type_as(X), X, X) # Adaptive scaling,`(G * X).sum() * X` == (G.T @ X).trace() * X
-
-                 asc.set_(X.reshape_as(asc).to(asc, copy=False)) # type:ignore
-
-         return ascent
-
-
-
- class DualNormCorrection(OptimizerModule):
-     """Dual norm correction from https://github.com/leloykun/adaptive-muon.
-
-     Description from the page:
-
-     Single-line modification to any (dualizer-based) optimizer that allows the optimizer to adapt to the scale of the gradients as they change during training.
-     This is done by scaling the dualized gradient by the clipped dual norm of the original gradient.
-     """
-     def __init__(self, adaptive_scale_min: int | None = -1, adaptive_scale_max: int | None = 1):
-         defaults = dict(adaptive_scale_min = adaptive_scale_min, adaptive_scale_max = adaptive_scale_max)
-         super().__init__(defaults)
-
-     def _update(self, vars, ascent):
-         params = self.get_params()
-         adaptive_scale_min, adaptive_scale_max = self.get_group_keys('adaptive_scale_min', 'adaptive_scale_max')
-
-         for asc, grad, min, max in zip(ascent, vars.maybe_compute_grad_(params), adaptive_scale_min, adaptive_scale_max):
-             if len([i for i in asc.shape if i > 1]) != 0:
-                 scale = torch.einsum('ij,ij->', grad.view(grad.shape[0], -1), asc.view(asc.shape[0], -1))
-                 if min is not None or max is not None: scale = scale.clip(min, max)
-                 asc *= scale
-
-         return ascent
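The `_zeropower_via_newtonschulz5` function above is self-contained, so its effect is easy to check: after a few iterations the singular values of the output are pushed towards the 0.5..1.5 band described in its docstring. A small check (illustrative only; uses float32 instead of the bfloat16 cast used above):

import torch

def newtonschulz5(G: torch.Tensor, steps: int = 6) -> torch.Tensor:
    # same quintic iteration as _zeropower_via_newtonschulz5 above, minus the bfloat16 cast
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + 1e-7)               # bound the spectral norm by 1
    if G.size(0) > G.size(1): X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(0) > G.size(1): X = X.T
    return X

G = torch.randn(64, 32)
print(torch.linalg.svdvals(newtonschulz5(G)))  # all roughly within ~0.5..1.5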
torchzero/modules/orthogonalization/svd.py +0 -86
@@ -1,86 +0,0 @@
- """Orthogonalization code adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers
-
- Tuddenham, M., Prügel-Bennett, A., & Hare, J. (2022).
- Orthogonalising gradients to speed up neural network optimisation. arXiv preprint arXiv:2202.07052.
- """
- import logging
- from collections.abc import Iterable, Sequence
-
- import torch
-
- from ...core import OptimizerModule, _Targets
-
- @torch.no_grad()
- def _orthogonalize_update_(updates: Sequence[torch.Tensor], toggle = None, warn_fail=True) -> None:
-     """adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers"""
-     if toggle is None: toggle = [True] * len(updates)
-
-     # Orthogonalise the gradients using SVD
-     for grad, orth in zip(updates, toggle):
-         if orth and grad.ndim > 1:
-             G: torch.Tensor = grad.view(grad.shape[0], -1)
-             orth_G: torch.Tensor | None = None
-             try:
-                 u, s, vt = torch.linalg.svd(G, full_matrices=False) # pylint:disable=not-callable
-                 orth_G = u @ vt
-             except RuntimeError:
-                 # if warn: logging.warning('Failed to perform SVD, adding some noise.')
-                 try:
-                     u, s, v = torch.svd_lowrank(
-                         G,
-                         q=1, # assume rank is at least 1
-                         M=1e-4 * G.mean() * torch.randn_like(G))
-                     orth_G = u @ v.T
-                 except RuntimeError:
-                     if warn_fail: logging.error(('Failed to perform SVD with noise,'
-                                                  ' skipping gradient orthogonalisation'))
-             if orth_G is not None:
-                 grad.set_(orth_G.reshape_as(grad)) # type:ignore
-
-     return updates
-
- def orthogonalize_grad_(params: Iterable[torch.Tensor], warn_fail=False):
-     """orthogonalizes gradients of an iterable of parameters.
-
-     This updates gradients in-place.
-
-     The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers
-     Args:
-         params (abc.Iterable[torch.Tensor]): parameters that hold gradients to orthogonalize.
-         warn_fail (bool, optional):
-             whether to print a warning when orthogonalization fails, and gradients are not
-             orthogonalized. Defaults to True.
-     """
-     grads = [p.grad for p in params if p.grad is not None]
-     _orthogonalize_update_(grads, warn_fail=warn_fail)
-
- class Orthogonalize(OptimizerModule):
-     """Orthogonalizes the update using SVD.
-
-     To disable orthogonalization for a parameter, put it into a parameter group with "orth" = False.
-
-     The orthogonalization code is adapted from https://github.com/MarkTuddenham/Orthogonal-Optimisers
-
-     Tip: :py:class:`tz.m.ZeropowerViaNewtonSchulz` is a significantly faster version of this.
-     Args:
-         warn_fail (bool, optional):
-             whether to print a warning when orthogonalization fails, and gradients are not
-             orthogonalized. Defaults to True.
-         target (str, optional):
-             determines what this module updates.
-
-             "ascent" - it updates the ascent
-
-             "grad" - it updates the gradient (and sets `.grad` attributes to updated gradient).
-
-             "closure" - it makes a new closure that sets the updated ascent to the .`grad` attributes.
-     """
-     def __init__(self, warn_fail=True, target: _Targets = 'ascent'):
-         defaults = dict(orth = True)
-         super().__init__(defaults, target = target)
-         self.warn_fail = warn_fail
-
-     def _update(self, vars, ascent):
-         toggle = self.get_group_key('orth', cls=list)
-         _orthogonalize_update_(ascent, toggle, self.warn_fail)
-         return ascent
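The core of `_orthogonalize_update_` is the exact SVD factorization: dropping S from U S V^T leaves a matrix whose singular values are all 1. On a single matrix (illustrative check only):

import torch

G = torch.randn(64, 32)
U, S, Vt = torch.linalg.svd(G, full_matrices=False)
G_orth = U @ Vt                                   # orthogonalized version of G
print(torch.allclose(G_orth.T @ G_orth, torch.eye(32), atol=1e-4))  # True: columns are orthonormal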
torchzero/modules/regularization/__init__.py +0 -22
@@ -1,22 +0,0 @@
- r"""
- This includes regularization modules like weight decay.
- """
- from .dropout import Dropout
- from .noise import AddNoise, Random, add_noise_
- from .normalization import (
-     Centralize,
-     ClipNorm,
-     ClipValue,
-     Normalize,
-     centralize_grad_,
-     clip_grad_norm_,
-     clip_grad_value_,
-     normalize_grad_,
- )
- from .weight_decay import (
-     WeightDecay,
-     l1_regularize_,
-     l2_regularize_,
-     weight_decay_penalty,
- )
- from .ortho_grad import OrthoGrad, orthograd_
torchzero/modules/regularization/dropout.py +0 -34
@@ -1,34 +0,0 @@
- import typing as T
- from collections import abc
-
- import torch
-
- from ...tensorlist import Distributions, TensorList
- from ...core import OptimizerModule
-
-
- class Dropout(OptimizerModule):
-     """
-     Applies dropout to the update - sets random elements to 0.
-
-     This can be used to apply learning rate dropout, if put after other modules, or gradient dropout,
-     if put first.
-
-     Args:
-         p (float, optional): probability to replace update value with zero. Defaults to 0.5.
-
-     reference
-         *Lin, H., Zeng, W., Zhuang, Y., Ding, X., Huang, Y., & Paisley, J. (2022).
-         Learning rate dropout. IEEE Transactions on Neural Networks and Learning Systems,
-         34(11), 9029-9039.*
-     """
-     def __init__(self, p: float = 0.5):
-         defaults = dict(p = p)
-         super().__init__(defaults)
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         p = self.get_group_key('p')
-
-         ascent *= ascent.bernoulli_like(p)
-         return ascent
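The dropout applied here is an element-wise Bernoulli mask on the update rather than on activations. A hedged sketch in plain PyTorch (illustrative only; the removed module used the TensorList API):

import torch

def update_dropout(update: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # zero each element with probability p, keep it otherwise
    keep = torch.bernoulli(torch.full_like(update, 1.0 - p))
    return update * keep

u = torch.randn(8)
print(update_dropout(u))  # about half of the entries are exactly 0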
torchzero/modules/regularization/noise.py +0 -77
@@ -1,77 +0,0 @@
- from collections import abc
- from typing import Literal
-
- import torch
-
- from ...core import OptimizerModule
- from ...tensorlist import Distributions, TensorList, _Scalar, _ScalarSequence
-
-
- def add_noise_(
-     grads: abc.Iterable[torch.Tensor],
-     alpha: "_Scalar | _ScalarSequence" = 1e-2,
-     distribution: Distributions = "normal",
-     mode: Literal["absolute", "global", "param", "channel"] = "param",
- ):
-     if not isinstance(grads, TensorList): grads = TensorList(grads)
-     if mode == 'absolute':
-         grads += grads.sample_like(alpha, distribution)
-
-     elif mode == 'global':
-         grads += grads.sample_like((grads.total_vector_norm(1)/grads.total_numel() * alpha).detach().cpu().item(), distribution) # type:ignore
-
-     elif mode == 'param':
-         grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
-
-     elif mode == 'channel':
-         grads = grads.unbind_channels()
-         grads += grads.sample_like(grads.abs().mean()*alpha, distribution)
-
- class AddNoise(OptimizerModule):
-     """Add noise to update. By default noise magnitude is relative to the mean of each parameter.
-
-     Args:
-         alpha (float, optional): magnitude of noise. Defaults to 1e-2.
-         distribution (Distributions, optional): distribution of noise. Defaults to 'normal'.
-         mode (str, optional):
-             how to calculate noise magnitude.
-
-             - "absolute": ignores gradient magnitude and always uses `alpha` as magnitude.
-
-             - "global": multiplies `alpha` by mean of the entire gradient, as if it was a single vector.
-
-             - "param": multiplies `alpha` by mean of each individual parameter (default).
-
-             - "channel": multiplies `alpha` by mean of each channel of each parameter.
-     """
-
-     def __init__(
-         self,
-         alpha: float = 1.,
-         distribution: Distributions = "normal",
-         mode: Literal["absolute", "global", "param", "channel"] = "param",
-     ):
-         defaults = dict(alpha = alpha)
-         super().__init__(defaults)
-         self.distribution: Distributions = distribution
-         self.mode: Literal["absolute", "global", "param", "channel"] = mode
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         alpha = self.get_group_key('alpha')
-
-         add_noise_(ascent, alpha, self.distribution, self.mode)
-         return ascent
-
- class Random(OptimizerModule):
-     """uses a random vector as the update. The vector is completely random and isn't checked to be descent direction.
-     This is therefore mainly useful in combination with other modules like Sum, Multiply, etc."""
-     def __init__(self, alpha: float = 1, distribution: Distributions = "normal"):
-         defaults = dict(alpha = alpha)
-         super().__init__(defaults)
-         self.distribution: Distributions = distribution
-
-     @torch.no_grad
-     def _update(self, vars, ascent):
-         alpha = self.get_group_key('alpha')
-         return ascent.sample_like(alpha, self.distribution)
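The default "param" mode above scales the noise by the mean absolute value of each tensor, so the perturbation stays relative to the update's own magnitude. A hedged sketch in plain PyTorch (illustrative only, not the TensorList-based add_noise_):

import torch

def add_relative_noise_(tensors: list[torch.Tensor], alpha: float = 1e-2) -> None:
    # per-tensor noise scale: alpha times the mean absolute value of that tensor
    for t in tensors:
        t.add_(torch.randn_like(t), alpha=float(alpha * t.abs().mean()))

grads = [torch.randn(4, 4), torch.randn(10)]
add_relative_noise_(grads)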