torchzero 0.1.7-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -494
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.7.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -132
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.7.dist-info/METADATA +0 -120
  199. torchzero-0.1.7.dist-info/RECORD +0 -104
  200. torchzero-0.1.7.dist-info/top_level.txt +0 -1
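The deletion hunks that follow come from the 0.1.7 preset layer under torchzero/optim/, whose classes simply composed reusable update modules through Modular. As rough orientation only, here is a minimal sketch of that composition pattern — assuming torchzero 0.1.7 is installed, and with the absolute import paths (torchzero.optim.modular, torchzero.modules) inferred from the relative imports visible in the deleted files, so treat them as an assumption rather than the current 0.3.1 API. The removed CautiousAdamW preset amounted to a chain like this:

import torch
from torchzero.optim.modular import Modular          # module file removed in 0.3.1
from torchzero.modules import Adam, LR, Cautious, WeightDecay

model = torch.nn.Linear(10, 1)
opt = Modular(
    model.parameters(),
    [
        Adam(beta1=0.9, beta2=0.999, eps=1e-8),            # adaptive update
        LR(1e-3),                                          # scale by the learning rate
        Cautious(normalize=False, eps=1e-6, mode='zero'),  # zero sign-inconsistent entries
        WeightDecay(1e-2),                                 # decoupled weight decay applied last
    ],
)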
@@ -1,343 +0,0 @@
- from typing import Literal
-
- from ...modules import (
-     LR,
-     SGD,
-     Abs,
-     Adam,
-     Add,
-     AddMagnitude,
-     Cautious,
-     Div,
-     Divide,
-     Grad,
-     HeavyBall,
-     Interpolate,
-     Lerp,
-     Multistep,
-     NanToNum,
-     NesterovMomentum,
-     Normalize,
-     Random,
-     RDiv,
-     Reciprocal,
-     UseGradSign,
-     WeightDecay,
- )
- from ...modules import RandomCoordinateMomentum as _RandomCoordinateMomentum
- from ...modules.experimental import GradMin as _GradMin
- from ...modules.experimental import (
-     HVPDiagNewton as _HVPDiagNewton,
- )
- from ...modules.experimental import MinibatchRprop as _MinibatchRprop
- from ...modules.experimental import ReduceOutwardLR
- from ...random import Distributions
- from ..modular import Modular
-
-
- class HVPDiagNewton(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - this should approximate newton method with 2 backward passes, but only if hessian is purely diagonal"""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-1,
-         eps: float = 1e-2,
-     ):
-         modules = [_HVPDiagNewton(eps = eps), LR(lr)]
-         super().__init__(params, modules)
-
-
- class ReciprocalSGD(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - this basically uses normalized *1 / (gradient + eps)*."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         eps: float = 1e-2,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             AddMagnitude(eps, add_to_zero=False),
-             Reciprocal(),
-             NanToNum(0,0,0),
-             Normalize(1),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-
-         super().__init__(params, modules)
-
- class NoiseSign(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - uses random vector with gradient sign, and works quite well despite being completely random."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         distribution: Distributions = 'normal',
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             Random(1, distribution),
-             UseGradSign(),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(2, WeightDecay(weight_decay))
-
-         super().__init__(params, modules)
-
- class MomentumNumerator(Modular):
-     """for experiments, unlikely to work well on most problems. (somewhat promising)
-
-     explanation - momentum divided by gradient."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         momentum: float = 0.9,
-         nesterov: bool = True,
-         eps: float = 1e-2,
-         weight_decay: float = 0,
-         decoupled=True, ):
-
-         modules: list = [
-             Divide(
-                 numerator = SGD(momentum = momentum, nesterov=nesterov),
-                 denominator=[Abs(), Add(eps)]
-             ),
-             Normalize(),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class MomentumDenominator(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - gradient divided by normalized momentum."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         momentum: float = 0.9,
-         nesterov: bool = True,
-         eps: float = 1e-2,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             Div([SGD(momentum=momentum, nesterov=nesterov), Abs(), Add(eps), Normalize(1)]),
-             Normalize(),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class ExaggeratedNesterov(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - exaggerates difference between heavyball and nesterov momentum."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         momentum: float = 0.9,
-         dampening: float = 0,
-         strength: float = 5,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-
-         modules: list = [
-             Interpolate(HeavyBall(momentum, dampening), NesterovMomentum(momentum, dampening), strength),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class ExtraCautiousAdam(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - caution with true backtracking."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad=False,
-         normalize = False,
-         c_eps = 1e-6,
-         mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-         strength = 5,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             Adam(beta1, beta2, eps, amsgrad=amsgrad),
-             Lerp(Cautious(normalize, c_eps, mode), strength),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class InwardSGD(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - reduces lrs for updates that move weights away from 0."""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         mul = 0.5,
-         use_grad=False,
-         invert=False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-             ReduceOutwardLR(mul, use_grad, invert),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class MultistepSGD(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - perform multiple steps per batch. Momentum applies to the total update over multiple step"""
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         num_steps=2,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         # lr, lr_module = _get_baked_in_and_module_lr(lr, kwargs) # multistep must use lr
-
-         modules: list = [
-             Multistep(LR(lr), num_steps=num_steps),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class MinibatchRprop(Modular):
-     """
-     for experiments, unlikely to work well on most problems.
-
-     explanation: does 2 steps per batch, applies rprop rule on the second step.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1,
-         nplus: float = 1.2,
-         nminus: float = 0.5,
-         lb: float | None = 1e-6,
-         ub: float | None = 50,
-         backtrack=True,
-         next_mode = 'continue',
-         increase_mul = 0.5,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _MinibatchRprop(nplus=nplus,nminus=nminus,lb=lb,ub=ub,backtrack=backtrack,next_mode=next_mode,increase_mul=increase_mul),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class RandomCoordinateMomentum(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     Only uses `p` random coordinates of the new update. Other coordinates remain from previous update.
-     This works but I don't know if it is any good.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         p (float, optional): probability to update velocity with a new weigh value. Defaults to 0.1.
-         nesterov (bool, optional): if False, update uses delayed momentum. Defaults to True.
-
-     """
-
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         p: float = 0.1,
-         nesterov: bool = True,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [_RandomCoordinateMomentum(p, nesterov), LR(lr)]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class GradMin(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - this uses gradient wrt sum of gradients + loss."""
-
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         loss_term: float = 1,
-         square: bool = False,
-         maximize_grad: bool = False,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _GradMin(loss_term, square, maximize_grad),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
@@ -1,83 +0,0 @@
- from typing import Literal, Any
-
- import torch
-
- from ...core import OptimizerModule
- from ...modules import (SGD, LineSearches, NewtonFDM,
-                         get_line_search, LR, WrapClosure)
- from ...modules.experimental.subspace import Subspace, ProjNormalize, ProjAscentRay
- from ..modular import Modular
-
-
- class NewtonFDMRaySearch(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - like a fancy line search, instead of a line searches in a cone using FDM newton."""
-     def __init__(
-         self,
-         params,
-         lr = 1e-2,
-         momentum:float = 0,
-         weight_decay:float = 0,
-         dampening: float = 0,
-         nesterov:bool = False,
-         n_rays = 3,
-         eps = 1e-2,
-         ray_width: float = 1e-1,
-         line_search: LineSearches | None = 'brent'
-     ):
-         modules: list[Any] = [
-             SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
-             LR(lr),
-             Subspace(NewtonFDM(eps = eps), ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
-         ]
-         if lr != 1:
-             modules.append(LR(lr))
-
-         if line_search is not None:
-             modules.append(get_line_search(line_search))
-
-         super().__init__(params, modules)
-
-
- class LBFGSRaySearch(Modular):
-     """for experiments, unlikely to work well on most problems.
-
-     explanation - like a fancy line search, instead of a line searches in a cone using LBFGS."""
-     def __init__(
-         self,
-         params,
-         lr = 1,
-         momentum:float = 0,
-         weight_decay:float = 0,
-         dampening: float = 0,
-         nesterov:bool = False,
-         n_rays = 24,
-         ray_width: float = 1e-1,
-         max_iter: int = 20,
-         max_eval: int | None = None,
-         tolerance_grad: float = 1e-7,
-         tolerance_change: float = 1e-9,
-         history_size: int = 100,
-         line_search_fn: str | Literal['strong_wolfe'] | None = None,
-     ):
-         lbfgs = WrapClosure(
-             torch.optim.LBFGS,
-             lr=lr,
-             max_iter=max_iter,
-             max_eval=max_eval,
-             tolerance_grad=tolerance_grad,
-             tolerance_change=tolerance_change,
-             history_size=history_size,
-             line_search_fn=line_search_fn,
-         )
-         modules: list[OptimizerModule] = [
-             SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
-             Subspace(lbfgs, ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
-
-         ]
-
-         super().__init__(params, modules)
-
-
-
@@ -1,18 +0,0 @@
- from .cautious import CautiousAdamW, CautiousLion, CautiousSGD
- from .optimizers import (
-     GD,
-     SGD,
-     Adagrad,
-     Adam,
-     AdamW,
-     Grams,
-     LaplacianSmoothingSGD,
-     Lion,
-     NestedNesterov,
-     NoisySGD,
-     NormSGD,
-     RMSProp,
-     Rprop,
-     SignSGD,
- )
- from .forward_gradient import ForwardGradient
@@ -1,158 +0,0 @@
- from typing import Literal
-
-
- from ...core import OptimizerModule
- from ...modules import Cautious, Adam, SGD, Lion, WeightDecay, LR
- from ..modular import Modular
-
-
- class CautiousAdamW(Modular):
-     """Adam, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-         beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-         amsgrad (bool, optional):
-             whether to use the AMSGrad variant of this algorithm from
-             On the Convergence of Adam and Beyond (default: False).
-         normalize (bool, optional):
-             renormalize update after masking.
-             only has effect when mode is 'zero'. Defaults to False.
-         c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-         mode (str, optional):
-             what to do with updates with inconsistent signs.
-
-             "zero" - set them to zero (as in paper)
-
-             "grad" - set them to the gradient
-
-             "negate" - negate them (same as using update magnitude and gradient sign).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad=False,
-         normalize = False,
-         c_eps = 1e-6,
-         mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list[OptimizerModule] = [
-             Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-             LR(lr),
-             Cautious(normalize = normalize, eps = c_eps, mode = mode),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class CautiousSGD(Modular):
-     """SGD with momentum, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         normalize (bool, optional):
-             renormalize update after masking.
-             only has effect when mode is 'zero'. Defaults to False.
-         c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-         mode (str, optional):
-             what to do with updates with inconsistent signs.
-
-             "zero" - set them to zero (as in paper)
-
-             "grad" - set them to the gradient
-
-             "negate" - negate them (same as using update magnitude and gradient sign).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentum: float = 0.9,
-         dampening: float = 0,
-         nesterov: bool = True,
-         c_eps = 1e-6,
-         normalize = False,
-         mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list[OptimizerModule] = [
-             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-             Cautious(normalize = normalize, eps = c_eps, mode = mode),
-         ]
-
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-
-         super().__init__(params, modules)
-
-
- class CautiousLion(Modular):
-     """Lion optimizer, but updates for parameters where update and gradient sign is inconsistent are negated.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): dampening for momentum. Defaults to 0.9.
-         beta2 (float, optional): momentum factor. Defaults to 0.99.
-         normalize (bool, optional):
-             renormalize update after masking.
-             only has effect when mode is 'zero'. Defaults to False.
-         c_eps (float, optional): epsilon for normalization after applying cautioning mask. Defaults to 1e-6.
-         mode (str, optional):
-             what to do with updates with inconsistent signs.
-
-             "zero" - set them to zero (as in paper)
-
-             "grad" - set them to the gradient
-
-             "negate" - negate them (same as using update magnitude and gradient sign).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.99,
-         c_eps = 1e-6,
-         normalize = False,
-         mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list[OptimizerModule] = [
-             Lion(beta1, beta2),
-             LR(lr),
-             Cautious(normalize = normalize, eps = c_eps, mode = mode),
-         ]
-
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-
-         super().__init__(params, modules)
@@ -1,70 +0,0 @@
- from typing import Literal
-
- import torch
-
- from ...core import OptimizationVars, OptimizerModule
- from ...modules import ForwardGradient as _ForwardGradient, SGD, WeightDecay, LR
- from ...tensorlist import Distributions
- from ..modular import Modular
-
-
- class ForwardGradient(Modular):
-     """
-
-     Evaluates jacobian-vector product with a random vector using forward mode autodiff (torch.func.jvp), which is
-     the true directional derivative in the direction of that vector.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float, optional): learning rate. Defaults to 1e-3.
-         n_samples (int): number of forward gradients to evaluate and average.
-         distribution (Distributions): distribution for random tangent vector.
-         mode (str):
-             "jvp" - uses forward mode AD, usually slightly slower than backward mode AD but uses significantly less memory.
-
-             "grad" - evaluates gradient with `loss.backward()` which may be faster but uses all the memory, mainly useful for
-             benchmarking as there is probably no point in forward gradient if full gradient is available.
-
-             "fd" - uses finite difference to estimate JVP, doesn't require gradients to be known. Equivalent to randomized FDM.
-
-         fd_eps (float, optional): epsilon for finite difference, only has effect if mode is "fd". Defaults to 1e-4.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     Reference:
-         Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022).
-         Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
-         https://arxiv.org/abs/2202.08587
-     """
-
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         n_samples: int = 1,
-         distribution: Distributions = "normal",
-         mode: Literal["jvp", "grad", "fd"] = "jvp",
-         fd_eps: float = 1e-4,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _ForwardGradient(
-                 n_samples=n_samples,
-                 distribution=distribution,
-                 mode=mode,
-                 fd_eps=fd_eps,
-             ),
-             SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)