torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/optimizers/rprop.py
@@ -1,99 +1,342 @@
1
- from collections import abc
2
-
3
- import torch
4
-
5
- from ...tensorlist import TensorList, where
6
- from ...core import OptimizerModule
7
-
8
-
9
- def _bool_ones_like(x):
10
- return torch.ones_like(x, dtype=torch.bool)
11
-
12
- class Rprop(OptimizerModule):
13
- """
14
- Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
15
- or `nminus` if it did. Then the update is applied with the sign of the current gradient.
16
-
17
- Additionally, if gradient changes sign, the update for that weight is reverted.
18
- Next step, magnitude for that weight won't change.
19
-
20
- Compared to pytorch this also implements backtracking update when sign changes.
21
- To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
22
-
23
- Args:
24
- nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
25
- nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
26
- lb (float): minimum step size, can be None (default: 1e-6)
27
- ub (float): maximum step size, can be None (default: 50)
28
- backtrack (float):
29
- if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
30
- When this is False, this exactly matches pytorch Rprop. (default: True)
31
- alpha (float): learning rate (default: 1).
32
-
33
- reference
34
- *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
35
- The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
36
- """
37
- def __init__(
38
- self,
39
- nplus: float = 1.2,
40
- nminus: float = 0.5,
41
- lb: float | None = 1e-6,
42
- ub: float | None = 50,
43
- backtrack=True,
44
- alpha: float = 1,
45
- ):
46
- defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
47
- super().__init__(defaults)
48
- self.current_step = 0
49
- self.backtrack = backtrack
50
-
51
- @torch.no_grad
52
- def _update(self, vars, ascent):
53
- params = self.get_params()
54
-
55
- sign = ascent.sign_()
56
- nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
57
- prev, allowed, magnitudes = self.get_state_keys(
58
- 'prev', 'allowed', 'magnitudes',
59
- inits = [torch.zeros_like, _bool_ones_like, torch.zeros_like],
60
- params=params
61
- )
62
-
63
- # initialize on 1st step
64
- if self.current_step == 0:
65
- magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
66
- ascent = magnitudes * sign
67
- prev.copy_(ascent)
68
- self.current_step += 1
69
- return ascent
70
-
71
- mul = (sign * prev).mul_(allowed)
72
-
73
- sign_changed = mul < 0
74
- sign_same = mul > 0
75
- zeroes = mul == 0
76
-
77
- mul.fill_(1)
78
- mul.masked_fill_(sign_changed, nminus)
79
- mul.masked_fill_(sign_same, nplus)
80
-
81
- # multiply magnitudes based on sign change and clamp to bounds
82
- magnitudes.mul_(mul).clamp_(lb, ub)
83
-
84
- # revert update if sign changed
85
- if self.backtrack:
86
- ascent = sign.mul_(magnitudes)
87
- ascent.masked_set_(sign_changed, prev.neg_())
88
- else:
89
- ascent = sign.mul_(magnitudes * ~sign_changed)
90
-
91
- # update allowed to only have weights where last update wasn't reverted
92
- allowed.set_(sign_same | zeroes)
93
-
94
- prev.copy_(ascent)
95
- self.current_step += 1
96
- return ascent
97
-
98
-
99
-
1
+
2
+ import torch
3
+
4
+ from ...core import Module, Target, Transform
5
+ from ...utils import NumberList, TensorList, as_tensorlist
6
+
7
+
8
+ def _bool_ones_like(x):
9
+ return torch.ones_like(x, dtype=torch.bool)
10
+
11
+ def sign_consistency_lrs_(
12
+ tensors: TensorList,
13
+ prev_: TensorList,
14
+ lrs_: TensorList,
15
+ nplus: float | NumberList,
16
+ nminus: float | NumberList,
17
+ lb: float | NumberList,
18
+ ub: float | NumberList,
19
+ step: int,
20
+ ):
21
+ """returns `lrs_`"""
22
+ sign = tensors.sign()
23
+ if step == 0:
24
+ prev_.set_(sign)
25
+ return lrs_.clamp_(lb, ub)
26
+
27
+ mul = sign * prev_
28
+ prev_.set_(sign)
29
+
30
+ sign_changed = mul < 0
31
+ sign_same = mul > 0
32
+
33
+ mul.fill_(1)
34
+ mul.masked_fill_(sign_changed, nminus)
35
+ mul.masked_fill_(sign_same, nplus)
36
+
37
+ # multiply magnitudes based on sign change and clamp to bounds
38
+ lrs_.mul_(mul).clamp_(lb, ub)
39
+ return lrs_
40
+
41
+ def scale_by_sign_change_(
42
+ tensors_: TensorList,
43
+ cur: TensorList,
44
+ prev_: TensorList,
45
+ lrs_: TensorList,
46
+ nplus: float | NumberList,
47
+ nminus: float | NumberList,
48
+ lb: float | NumberList,
49
+ ub: float | NumberList,
50
+ step: int,
51
+ ):
52
+ """returns `tensors_`"""
53
+ lrs_ = sign_consistency_lrs_(cur,prev_=prev_,lrs_=lrs_,nplus=nplus,nminus=nminus,
54
+ lb=lb,ub=ub,step=step)
55
+ return tensors_.mul_(lrs_)
56
+
57
+ def backtrack_on_sign_change_(
58
+ tensors_: TensorList,
59
+ cur: TensorList,
60
+ prev_: TensorList,
61
+ backtrack: bool,
62
+ step: int
63
+ ):
64
+ """returns `tensors_`."""
65
+ if step == 0:
66
+ prev_.set_(cur)
67
+ return tensors_
68
+
69
+ # mask is True for parameters where the sign flipped between `cur` and the previous step
70
+ mask = (cur * prev_) < 0
71
+ if backtrack: tensors_.masked_set_(mask, prev_)
72
+ else: tensors_.select_set_(mask, 0)
73
+
74
+ prev_.set_(cur)
75
+ return tensors_
76
+
77
+ def rprop_(
78
+ tensors_: TensorList,
79
+ prev_: TensorList,
80
+ allowed_: TensorList,
81
+ magnitudes_: TensorList,
82
+ nplus: float | NumberList,
83
+ nminus: float | NumberList,
84
+ lb: float | NumberList,
85
+ ub: float | NumberList,
86
+ alpha: float | NumberList,
87
+ backtrack: bool,
88
+ step: int,
89
+ ):
90
+ """returns new tensors."""
91
+
92
+ sign = tensors_.sign_()
93
+
94
+ # initialize on 1st step
95
+ if step == 0:
96
+ magnitudes_.fill_(alpha).clamp_(lb, ub)
97
+ new_tensors = magnitudes_ * sign
98
+ prev_.copy_(new_tensors)
99
+ return new_tensors
100
+
101
+ mul = (sign * prev_).mul_(allowed_)
102
+
103
+ sign_changed = mul < 0
104
+ sign_same = mul > 0
105
+ zeroes = mul == 0
106
+
107
+ mul.fill_(1)
108
+ mul.masked_fill_(sign_changed, nminus)
109
+ mul.masked_fill_(sign_same, nplus)
110
+
111
+ # multiply magnitudes based on sign change and clamp to bounds
112
+ magnitudes_.mul_(mul).clamp_(lb, ub)
113
+
114
+ # revert update if sign changed
115
+ if backtrack:
116
+ new_tensors = sign.mul_(magnitudes_)
117
+ new_tensors.masked_set_(sign_changed, prev_.neg_())
118
+ else:
119
+ new_tensors = sign.mul_(magnitudes_ * ~sign_changed)
120
+
121
+ # update allowed to only have weights where last update wasn't reverted
122
+ allowed_.set_(sign_same | zeroes)
123
+
124
+ prev_.copy_(new_tensors)
125
+ return new_tensors
126
+
127
+
128
+
129
+ class Rprop(Transform):
130
+ """
131
+ Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
132
+ or `nminus` if it did. Then the update is applied with the sign of the current gradient.
133
+
134
+ Additionally, if gradient changes sign, the update for that weight is reverted.
135
+ Next step, magnitude for that weight won't change.
136
+
137
+ Compared to pytorch this also implements backtracking update when sign changes.
138
+ To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
139
+
140
+ Args:
141
+ nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
142
+ nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
143
+ lb (float): minimum step size, can be None (default: 1e-6)
144
+ ub (float): maximum step size, can be None (default: 50)
145
+ backtrack (bool):
146
+ if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
147
+ When this is False, this exactly matches pytorch Rprop. (default: True)
148
+ alpha (float): initial per-parameter learning rate (default: 1).
149
+
150
+ reference
151
+ *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
152
+ The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
153
+ """
154
+ def __init__(
155
+ self,
156
+ nplus: float = 1.2,
157
+ nminus: float = 0.5,
158
+ lb: float = 1e-6,
159
+ ub: float = 50,
160
+ backtrack=True,
161
+ alpha: float = 1,
162
+ ):
163
+ defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
164
+ self.current_step = 0
165
+ super().__init__(defaults, uses_grad=False)
166
+
167
+ @torch.no_grad
168
+ def transform(self, tensors, params, grads, vars):
169
+ nplus, nminus, lb, ub, alpha = self.get_settings('nplus', 'nminus', 'lb', 'ub', 'alpha', params=params, cls=NumberList)
170
+ prev, allowed, magnitudes = self.get_state(
171
+ 'prev','allowed','magnitudes',
172
+ params=params,
173
+ init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
174
+ cls = TensorList,
175
+ )
176
+
177
+ target = rprop_(
178
+ tensors_ = as_tensorlist(tensors),
179
+ prev_ = prev,
180
+ allowed_ = allowed,
181
+ magnitudes_ = magnitudes,
182
+ nplus = nplus,
183
+ nminus = nminus,
184
+ lb = lb,
185
+ ub = ub,
186
+ alpha = alpha,
187
+ backtrack=self.settings[params[0]]['backtrack'],
188
+ step=self.current_step,
189
+ )
190
+
191
+ self.current_step += 1
192
+ return target
193
+
194
+
195
+ class ScaleLRBySignChange(Transform):
196
+ """
197
+ learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign,
198
+ or `nminus` if it did.
199
+
200
+ This is part of RProp update rule.
201
+
202
+ Args:
203
+ nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
204
+ nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
205
+ lb (float): lower bound for lr.
206
+ ub (float): upper bound for lr.
207
+ alpha (float): initial learning rate.
208
+
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ nplus: float = 1.2,
214
+ nminus: float = 0.5,
215
+ lb=1e-6,
216
+ ub=50.0,
217
+ alpha=1.0,
218
+ use_grad=False,
219
+ target: Target = "update",
220
+ ):
221
+ defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
222
+ super().__init__(defaults, uses_grad=use_grad, target=target)
223
+ self.current_step = 0
224
+
225
+ @torch.no_grad
226
+ def transform(self, tensors, params, grads, vars):
227
+ target = as_tensorlist(tensors)
228
+ use_grad = self.settings[params[0]]['use_grad']
229
+ if use_grad: cur = as_tensorlist(grads)
230
+ else: cur = target
231
+
232
+ nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
233
+ prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
234
+
235
+ if self.current_step == 0:
236
+ lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
237
+
238
+ target = scale_by_sign_change_(
239
+ tensors_ = target,
240
+ cur = cur,
241
+ prev_ = prev,
242
+ lrs_ = lrs,
243
+ nplus = nplus,
244
+ nminus = nminus,
245
+ lb = lb,
246
+ ub = ub,
247
+ step = self.current_step,
248
+ )
249
+ self.current_step += 1
250
+ return target
251
+
252
+ class BacktrackOnSignChange(Transform):
253
+ """Negates or undoes the update for parameters where the gradient or update sign changes.
254
+
255
+ This is part of RProp update rule.
256
+
257
+ Args:
258
+ normalize (bool, optional): renormalize update after masking. Defaults to False.
259
+ eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
260
+ use_grad (bool, optional):
261
+ if True, tracks sign change of the gradient,
262
+ otherwise tracks sign change of the update. Defaults to False.
263
+ backtrack (bool, optional):
264
+ if True, undoes the update when sign changes, otherwise negates it.
265
+ Defaults to True.
266
+
267
+ """
268
+ def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
269
+ defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
270
+ super().__init__(defaults, uses_grad=use_grad)
271
+ self.current_step = 0
272
+
273
+ @torch.no_grad
274
+ def transform(self, tensors, params, grads, vars):
275
+ target = as_tensorlist(tensors)
276
+ settings = self.settings[params[0]]
277
+ use_grad = settings['use_grad']
278
+ backtrack = settings['backtrack']
279
+
280
+ if use_grad: cur = as_tensorlist(grads)
281
+ else: cur = target
282
+
283
+ target = backtrack_on_sign_change_(
284
+ tensors_ = target,
285
+ cur = cur,
286
+ prev_ = self.get_state('prev', params=params, cls=TensorList),
287
+ backtrack = backtrack,
288
+ step = self.current_step,
289
+ )
290
+
291
+ self.current_step += 1
292
+ return target
293
+
294
+ class SignConsistencyMask(Transform):
295
+ """Outputs a mask that is 0 where the update sign changed since the previous step and 1 otherwise."""
296
+ def __init__(self,target: Target = 'update'):
297
+ super().__init__({}, uses_grad=False, target = target)
298
+
299
+ @torch.no_grad
300
+ def transform(self, tensors, params, grads, vars):
301
+ prev = self.get_state('prev', params=params, cls=TensorList)
302
+ mask = prev.mul_(tensors).gt_(0)
303
+ prev.set_(tensors)
304
+ return mask
305
+
306
+
307
+ class SignConsistencyLRs(Transform):
308
+ """LR for each weight is increased when two consecutive update signs are the same, decreased otherwise. This returns the LRs themselves."""
309
+ def __init__(
310
+ self,
311
+ nplus: float = 1.2,
312
+ nminus: float = 0.5,
313
+ lb: float | None = 1e-6,
314
+ ub: float | None = 50,
315
+ alpha: float = 1,
316
+ target: Target = 'update'
317
+ ):
318
+ defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
319
+ super().__init__(defaults, uses_grad=False, target = target)
320
+ self.current_step = 0
321
+
322
+ @torch.no_grad
323
+ def transform(self, tensors, params, grads, vars):
324
+ target = as_tensorlist(tensors)
325
+ nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
326
+ prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
327
+
328
+ if self.current_step == 0:
329
+ lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
330
+
331
+ target = sign_consistency_lrs_(
332
+ tensors = target,
333
+ prev_ = prev,
334
+ lrs_ = lrs,
335
+ nplus = nplus,
336
+ nminus = nminus,
337
+ lb = lb,
338
+ ub = ub,
339
+ step = self.current_step,
340
+ )
341
+ self.current_step += 1
342
+ return target.clone()
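
For reference, the sign-change rule implemented by the Rprop and ScaleLRBySignChange modules above can be sketched with plain torch tensors. This is a minimal illustration under stated assumptions, not the package's API: rprop_step is a hypothetical name, the TensorList/Transform machinery is ignored, the `allowed` bookkeeping that freezes a weight's step size for one step after a sign flip is omitted, and it follows the non-backtracking (torch.optim.Rprop-style) variant described in the docstring.

import torch

def rprop_step(grad, prev_sign, magnitudes, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    # grow the per-weight step size where the gradient kept its sign, shrink it where it flipped
    sign = grad.sign()
    same = (sign * prev_sign) > 0
    flipped = (sign * prev_sign) < 0
    factor = torch.ones_like(grad)
    factor[same] = nplus
    factor[flipped] = nminus
    magnitudes.mul_(factor).clamp_(lb, ub)
    # the step is taken with the sign of the current gradient;
    # where the sign flipped, this variant zeroes the step instead of backtracking
    update = sign * magnitudes
    update[flipped] = 0.0
    prev_sign.copy_(sign)
    return update  # subtract this from the weights

w = torch.randn(10, requires_grad=True)
prev_sign = torch.zeros_like(w)       # zero means "no previous sign" on the first step
magnitudes = torch.full_like(w, 1.0)  # initial per-weight step size (the alpha default above)
(w ** 2).sum().backward()
with torch.no_grad():
    w -= rprop_step(w.grad, prev_sign, magnitudes)
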
torchzero/modules/optimizers/shampoo.py (new file)
@@ -0,0 +1,197 @@
1
+ from collections.abc import Sequence
2
+ from operator import itemgetter
3
+ from functools import partial
4
+ import numpy as np
5
+ import torch
6
+
7
+ from ...core import Chainable, Transform, apply
8
+ from ...utils.linalg import matrix_power_eigh
9
+ from ...utils import set_storage_
10
+
11
+
12
+ def update_shampoo_preconditioner_(
13
+ grad: torch.Tensor,
14
+ accumulators_: list[torch.Tensor | None],
15
+ preconditioners_: list[torch.Tensor | None],
16
+ step: int,
17
+ update_freq: int,
18
+ exp_override: int | None,
19
+ beta: float | None,
20
+ ):
21
+ for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
22
+ if accumulator is None: continue
23
+ assert preconditioner is not None
24
+
25
+ axes = list(range(i)) + list(range(i + 1, grad.ndim))
26
+ if beta is None: accumulator.add_(torch.tensordot(grad, grad, (axes, axes))) # pyright:ignore[reportArgumentType]
27
+ else: accumulator.lerp_(torch.tensordot(grad, grad, (axes, axes)), 1-beta) # pyright:ignore[reportArgumentType]
28
+
29
+ if step % update_freq == 0:
30
+ matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
31
+ set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))
32
+
33
+
34
+ def apply_shampoo_preconditioner(
35
+ tensor: torch.Tensor,
36
+ preconditioners_: list[torch.Tensor | None],
37
+ decay: float | None,
38
+ ):
39
+ for i, preconditioner in enumerate(preconditioners_):
40
+ if preconditioner is None: continue
41
+ tensor = torch.tensordot(tensor, preconditioner, ([0], [0])) # pyright:ignore[reportArgumentType]
42
+ if decay is not None: preconditioner.mul_(decay)
43
+ return tensor
44
+
45
+
46
+ def update_diagonal_(grad: torch.Tensor, diagonal_accumulator_: torch.Tensor, beta: float | None):
47
+ if beta is None: diagonal_accumulator_.add_(grad.pow(2))
48
+ else: diagonal_accumulator_.mul_(beta).addcmul_(grad, grad, value=1-beta)
49
+
50
+ def apply_diagonal_(grad_: torch.Tensor, diagonal_accumulator_: torch.Tensor, decay: float | None, eps: float):
51
+ grad_.div_(diagonal_accumulator_.sqrt() + eps)
52
+ if decay is not None: diagonal_accumulator_.mul_(decay)
53
+ return grad_
54
+
55
+ def _merge_small_dims(tensor: torch.Tensor, max_dim: int):
56
+ """Merges the smallest dims (sorted by size) while their product stays within max_dim; returns the merged tensor and the info needed to undo the merge."""
57
+ if tensor.ndim == 0: return tensor, None, None
58
+ sort_idxs = np.argsort(tensor.shape)
59
+ if tensor.shape[sort_idxs[0]] > max_dim:
60
+ return tensor, None, None
61
+
62
+ tensor = tensor.permute(*sort_idxs)
63
+ flatten_end_idx = 0
64
+ flat_sizes = []
65
+ flat_numel = 1
66
+ for i, size in enumerate(tensor.shape):
67
+ if flat_numel * size <= max_dim:
68
+ flatten_end_idx = i
69
+ flat_numel *= size
70
+ flat_sizes.append(size)
71
+ else:
72
+ break
73
+
74
+ if flatten_end_idx != 0:
75
+ tensor = tensor.flatten(end_dim=flatten_end_idx)
76
+
77
+ return tensor, flat_sizes, sort_idxs
78
+
79
+ def _unmerge_small_dims(tensor: torch.Tensor, flat_sizes: Sequence[int] | None, sort_idxs: np.ndarray | Sequence[int] | None):
80
+ if flat_sizes is None: return tensor
81
+ assert sort_idxs is not None
82
+ tensor = tensor.unflatten(0, flat_sizes)
83
+ return tensor.permute(*np.argsort(sort_idxs))
84
+
85
+
86
+ class Shampoo(Transform):
87
+ """Shampoo from Preconditioned Stochastic Tensor Optimization (https://arxiv.org/abs/1802.09568).
88
+
89
+ Args:
90
+ decay (float | None, optional): slowly decays preconditioners. Defaults to None.
91
+ beta (float | None, optional):
92
+ if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
93
+ reg (float, optional): regularization epsilon for matrix operations. Defaults to 1e-6.
94
+ update_freq (int, optional): preconditioner update frequency. Defaults to 10.
95
+ exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to None.
96
+ merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
97
+ max_dim (int, optional): maximum dimension size for preconditioning. Defaults to 2_000.
98
+ precondition_1d (bool, optional): whether to precondition 1d tensors. Defaults to True.
99
+ adagrad_eps (float, optional): epsilon for adagrad division for tensors where shampoo can't be applied. Defaults to 1e-8.
100
+ inner (Chainable | None, optional):
101
+ module applied after updating preconditioners and before applying preconditioning.
102
+ For example if beta≈0.999 and `inner=tz.m.EMA(0.9)`, this becomes Adam with shampoo preconditioner (ignoring debiasing).
103
+ Defaults to None.
104
+ """
105
+ def __init__(
106
+ self,
107
+ decay: float | None = None,
108
+ beta: float | None = None,
109
+ reg: float = 1e-6,
110
+ update_freq: int = 10,
111
+ exp_override: int | None = None,
112
+ merge_small: bool = True,
113
+ max_dim: int = 2_000,
114
+ precondition_1d: bool = True,
115
+ adagrad_eps: float = 1e-8,
116
+ inner: Chainable | None = None,
117
+ ):
118
+ defaults = dict(decay=decay, beta=beta, reg=reg, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
119
+ super().__init__(defaults, uses_grad=False)
120
+
121
+ if inner is not None:
122
+ self.set_child('inner', inner)
123
+
124
+ def transform(self, tensors, params, grads, vars):
125
+ merged_target = [] # target with merged dims
126
+
127
+ # update preconditioners
128
+ for i,(p,t) in enumerate(zip(params, tensors)):
129
+ state = self.state[p]
130
+ settings = self.settings[p]
131
+ beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
132
+ 'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)
133
+
134
+ if merge_small:
135
+ t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
136
+ merged_target.append(t)
137
+
138
+ # initialize accumulators and preconditioners for each dim on 1st step
139
+ if 'accumulators' not in state:
140
+
141
+ if not precondition_1d and t.ndim <= 1:
142
+ state['accumulators'] = []
143
+
144
+ else:
145
+ state['accumulators'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
146
+ state['preconditioners'] = [torch.eye(s, dtype=t.dtype, device=t.device) if 1<s<max_dim else None for s in t.shape]
147
+
148
+ # either scalar parameter, 1d with precondition_1d=False, or too big, then basic diagonal preconditioner is used.
149
+ if all(i is None for i in state['accumulators']):
150
+ state['diagonal_accumulator'] = torch.zeros_like(t)
151
+
152
+ state['step'] = 0
153
+
154
+ # update preconditioners
155
+ if 'diagonal_accumulator' in state:
156
+ update_diagonal_(t, state['diagonal_accumulator'], beta)
157
+ else:
158
+ update_shampoo_preconditioner_(
159
+ t,
160
+ accumulators_=state['accumulators'],
161
+ preconditioners_=state['preconditioners'],
162
+ step=state['step'],
163
+ update_freq=update_freq,
164
+ exp_override=exp_override,
165
+ beta=beta,
166
+ )
167
+
168
+ # inner step
169
+ if 'inner' in self.children:
170
+ tensors = apply(self.children['inner'], tensors, params=params, grads=grads, vars=vars)
171
+
172
+ # have to merge small dims again
173
+ merged_target = [] # target with merged dims
174
+ for i,(p,t) in enumerate(zip(params, tensors)):
175
+ state = self.state[p]
176
+ settings = self.settings[p]
177
+ if settings['merge_small']:
178
+ t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
179
+ merged_target.append(t)
180
+
181
+ # precondition
182
+ for i, (p, t) in enumerate(zip(params, merged_target)):
183
+ state = self.state[p]
184
+ settings = self.settings[p]
185
+ decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)
186
+
187
+ if 'diagonal_accumulator' in state:
188
+ tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
189
+ else:
190
+ tensors[i] = apply_shampoo_preconditioner(t, preconditioners_=state['preconditioners'], decay=decay)
191
+
192
+ if merge_small:
193
+ tensors[i] = _unmerge_small_dims(tensors[i], state['flat_sizes'], state['sort_idxs'])
194
+
195
+ state['step'] += 1
196
+
197
+ return tensors
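
For intuition, the Kronecker-factored preconditioning that update_shampoo_preconditioner_ and apply_shampoo_preconditioner implement can be written out for a single 2-D gradient. This is a self-contained sketch under stated assumptions, not the module itself: matrix_power_sym is a stand-in for the package's matrix_power_eigh helper (whose implementation is not shown in this diff), the matrix powers are recomputed on every call rather than every update_freq steps, and merge_small, decay, the EMA beta and the diagonal adagrad fallback are all omitted.

import torch

def matrix_power_sym(mat, power, eps=1e-10):
    # fractional power of a symmetric PSD matrix via eigendecomposition
    vals, vecs = torch.linalg.eigh(mat)
    vals = vals.clamp_min(eps).pow(power)
    return (vecs * vals) @ vecs.mT

def shampoo_precondition_2d(grad, L, R, exp_override=None):
    # one accumulator per dimension: L contracts grad over dim 1, R over dim 0,
    # mirroring the tensordot calls in update_shampoo_preconditioner_
    L.add_(grad @ grad.mT)
    R.add_(grad.mT @ grad)
    # default exponent is -1/(2*ndim), i.e. -1/4 for a matrix
    p = -1.0 / (2 * grad.ndim) if exp_override is None else -1.0 / exp_override
    return matrix_power_sym(L, p) @ grad @ matrix_power_sym(R, p)

G = torch.randn(8, 4)
L = torch.eye(8)  # the module initializes accumulators to identity
R = torch.eye(4)
update = shampoo_precondition_2d(G, L, R)  # preconditioned direction, same shape as G

For the 2-D case this reproduces the direction P0 @ G @ P1 that apply_shampoo_preconditioner builds through its repeated tensordot contractions over dim 0.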