torchzero 0.1.8__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.1.dist-info/METADATA +379 -0
  124. torchzero-0.3.1.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.1.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.1.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
@@ -1,570 +0,0 @@
- from collections.abc import Iterable
- from typing import Literal
-
- from ...modules import (
-     LR,
-     AddNoise,
-     Centralize,
-     Grad,
-     HeavyBall,
-     LineSearches, LaplacianSmoothing,
-     NesterovMomentum,
-     Normalize,
-     Random,
-     Sign,
-     UseGradSign,
-     WeightDecay,
-     get_line_search,
- )
- from ...modules import SGD as _SGD
- from ...modules import Adagrad as _Adagrad
- from ...modules import Adam as _Adam
- from ...modules import Lion as _Lion
- from ...modules import RMSProp as _RMSProp
- from ...modules import Rprop as _Rprop
- from ...random.random import Distributions
- from ..modular import Modular
-
-
- class GD(Modular):
-     """Gradient descent with armijo line search.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1).
-         line_search (LineSearches | None, optional):
-             line search type. Defaults to 'armijo'.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1,
-         line_search: LineSearches | None = 'armijo',
-     ):
-         modules: list = [LR(lr)]
-         if line_search is not None: modules.append(get_line_search(line_search))
-
-         super().__init__(params, *modules)
-
- class SGD(Modular):
-     """Exactly matches `torch.optim.SGD`, except
-     nesterov momentum additionally supports dampening, negative momentum is allowed,
-     and weight decay supports decoupling.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr)
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class SignSGD(Modular):
-     """SGD that uses sign of the gradient, can act as a normalizer and improve stability.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             Sign(),
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class NormSGD(Modular):
-     """SGD with gradient normalization and optionally centralization.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float):
-             learning rate, gradients are normalized to this value.
-             This can typically be 10 times bigger than normal SGD (default: 1e-1).
-         centralize (bool, optional): whether to centralize gradients (default: True).
-         norm_mode (str, optional):
-             what to normalize.
-
-             - "global": normalize the entire gradient, as if it was a single vector.
-
-             - "param": normalize each param's gradient.
-
-             - "channel": normalize gradient of each channel of each param (default).
-         centralize_mode (str, optional): what to centralize (same options as `norm_mode`). Defaults to 'channel'.
-         min_numel (int, optional):
-             skips parameters with less than this many elements. This avoids the issue where
-             parameters that have a single element always get set to the value of 1.
-             Ignored when mode is 'global'. Defaults to 2.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-1,
-         normalize=True,
-         norm_mode: Literal["global", "param", "channel"] = 'channel',
-         ord = 2,
-         centralize=True,
-         centralize_mode: Literal["global", "param", "channel"] = 'channel',
-         min_numel=2,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         if normalize: modules.insert(0, Normalize(1, mode=norm_mode, min_numel=min_numel, ord=ord))
-         if centralize: modules.insert(0, Centralize(centralize_mode, min_numel=min_numel))
-         super().__init__(params, modules)
-
-
- class NoisySGD(Modular):
-     """SGD with noise added to gradients. The formula for noise magnitude is `alpha * mean(abs(grad))`.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3)
-         alpha (float, optional): magnitude of noise. Defaults to 1e-2.
-         distribution (Distributions, optional): distribution of noise. Defaults to 'normal'.
-         mode (str, optional):
-             how to calculate noise magnitude.
-
-             - "absolute": ignores gradient magnitude and always uses `alpha` as magnitude.
-
-             - "global": multiplies `alpha` by mean of the entire gradient, as if it was a single vector.
-
-             - "param": multiplies `alpha` by mean of each individual parameter (default).
-
-             - "channel": multiplies `alpha` by mean of each channel of each parameter.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         alpha: float = 1,
-         distribution: Distributions = 'normal',
-         mode: Literal["absolute", "global", "param", "channel"] = "param",
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-
-         modules: list = [
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
-             AddNoise(alpha, distribution, mode),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class LaplacianSmoothingSGD(Modular):
-     """SGD with laplacian smoothing.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3)
-         sigma (float, optional): controls the amount of smoothing. Defaults to 1.
-         layerwise (bool, optional):
-             If True, applies smoothing to each parameter's gradient separately,
-             Otherwise applies it to all gradients, concatenated into a single vector. Defaults to True.
-         min_numel (int, optional):
-             minimum number of elements in a parameter to apply laplacian smoothing to.
-             Only has effect if `layerwise` is True. Defaults to 4.
-         momentum (float, optional): momentum. Defaults to 0.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     Reference:
-         *Osher, S., Wang, B., Yin, P., Luo, X., Barekat, F., Pham, M., & Lin, A. (2022).
-         Laplacian smoothing gradient descent. Research in the Mathematical Sciences, 9(3), 55.*
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         sigma: float = 1,
-         layerwise: bool = True,
-         min_numel: int = 4,
-         momentum: float = 0,
-         dampening: float = 0,
-         nesterov: bool = False,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-
-         modules: list = [
-             LaplacianSmoothing(sigma=sigma, layerwise=layerwise,min_numel=min_numel),
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class Adagrad(Modular):
-     """Divides ascent direction by mean square root of the sum of all past ascent directions.
-
-     Exactly matches `torch.optim.Adagrad`.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         lr_decay (float, optional): learning rate decay. Defaults to 0.
-         initial_accumulator_value (float, optional): initial value of the sum of squares of gradients. Defaults to 0.
-         eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-10.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         lr_decay: float = 0,
-         initial_accumulator_value: float = 0,
-         eps: float = 1e-10,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _Adagrad(lr_decay = lr_decay, initial_accumulator_value = initial_accumulator_value, eps = eps),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class Rprop(Modular):
-     """
-     Resilient propagation. The update magnitude gets multiplied by `nplus` if gradient didn't change the sign,
-     or `nminus` if it did. Then the update is applied with the sign of the current gradient.
-
-     Additionally, if gradient changes sign, the update for that weight is reverted.
-     Next step, magnitude for that weight won't change.
-
-     Compared to pytorch this also implements backtracking update when sign changes.
-     To make this behave exactly the same as `torch.optim.Rprop`, set `backtrack` to False.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         nplus (float): multiplicative increase factor for when ascent didn't change sign (default: 1.2).
-         nminus (float): multiplicative decrease factor for when ascent changed sign (default: 0.5).
-         lb (float): minimum step size, can be None (default: 1e-6)
-         ub (float): maximum step size, can be None (default: 50)
-         backtrack (float):
-             if True, when ascent sign changes, undoes last weight update, otherwise sets update to 0.
-             When this is False, this exactly matches pytorch Rprop. (default: True)
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     reference
-         *Riedmiller, M., & Braun, H. (1993, March). A direct adaptive method for faster backpropagation learning:
-         The RPROP algorithm. In IEEE international conference on neural networks (pp. 586-591). IEEE.*
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         nplus: float = 1.2,
-         nminus: float = 0.5,
-         lb: float | None = 1e-6,
-         ub: float | None = 50,
-         backtrack=True,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _Rprop(nplus = nplus, nminus = nminus, lb=lb, ub = ub, backtrack=backtrack),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class RMSProp(Modular):
-     """
-     Divides ascent direction by running average of its mean square root.
-
-     Exactly matches `torch.optim.RMSProp`, except momentum initialization is arbitrarily different.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         momentum (float, optional): momentum. Defaults to 0.
-         alpha (float, optional):
-             smoothing constant (decay of ascent mean square root running average).
-             Defaults to 0.99.
-         eps (float, optional): term added to the denominator to improve numerical stability. Defaults to 1e-8.
-         centered (float, optional):
-             if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance.
-             Defaults to False.
-         dampening (float, optional): momentum dampening. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     reference
-         https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-2,
-         momentum: float = 0,
-         alpha: float = 0.99,
-         eps: float = 1e-8,
-         centered: bool = False,
-         nesterov = False,
-         dampening: float = 0,
-         weight_decay: float = 0,
-         decoupled=False,
-     ):
-         modules: list = [
-             _RMSProp(smoothing = alpha, eps = eps, centered = centered,),
-             _SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class Adam(Modular):
-     """Adam. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`, except
-     if `decoupled` is True, weight decay is truly decoupled and doesn't depend on LR.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-         beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-         amsgrad (bool, optional):
-             whether to use the AMSGrad variant of this algorithm from
-             On the Convergence of Adam and Beyond (default: False).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad=False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-             LR(lr),
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
- class AdamW(Adam):
-     """AdamW. Combines momentum and RMSProp. Exactly matches `torch.optim.Adam`, except
-     if `decoupled` is True, weight decay is truly decoupled and doesn't depend on LR.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-         beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-         amsgrad (bool, optional):
-             whether to use the AMSGrad variant of this algorithm from
-             On the Convergence of Adam and Beyond (default: False).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.01.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad=False,
-         weight_decay: float = 1e-2,
-         decoupled=True,
-     ):
-         super().__init__(params=params,lr=lr,beta1=beta1,beta2=beta2,eps=eps,amsgrad=amsgrad,weight_decay=weight_decay,decoupled=decoupled)
-
- class Grams(Modular):
-     """Grams (Gradient Descent with Adaptive Momentum Scaling) from https://arxiv.org/abs/2412.17107v1.
-     This is Adam but uses gradient sign.
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): exponential decay rate of gradient moving average. Defaults to 0.9.
-         beta2 (float, optional): exponential decay rate of squared gradient moving average. Defaults to 0.999.
-         eps (float, optional): epsilon for numerical stability. Defaults to 1e-8.
-         amsgrad (bool, optional):
-             whether to use the AMSGrad variant of this algorithm from
-             On the Convergence of Adam and Beyond (default: False).
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.999,
-         eps: float = 1e-8,
-         amsgrad=False,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _Adam(beta1 = beta1, beta2 = beta2, eps = eps, amsgrad = amsgrad),
-             LR(lr),
-             UseGradSign()
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
- class Lion(Modular):
-     """Lion (EvoLved Sign Momentum) optimizer from https://arxiv.org/abs/2302.06675.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         beta1 (float, optional): dampening for momentum. Defaults to 0.9.
-         beta2 (float, optional): momentum factor. Defaults to 0.99.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         beta1: float = 0.9,
-         beta2: float = 0.99,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         modules: list = [
-             _Lion(beta1, beta2),
-             LR(lr)
-         ]
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-         super().__init__(params, modules)
-
-
-
- class NestedNesterov(Modular):
-     """Chains multiple nesterov momentums. The default (0.5, 0.5) seems to work well.
-
-     Args:
-         params: iterable of parameters to optimize or dicts defining parameter groups.
-         lr (float): learning rate (default: 1e-3).
-         momentums (Iterable[float], optional): sequence of momentums. Defaults to (0.5, 0.5, 0.5).
-         dampening (float | Iterable[float], optional):
-             sequence of dampenings for each momentum, or a single float that is used
-             for all momentums. Defaults to 0.
-         nesterov (bool, optional):
-             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to True.
-         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
-         decoupled (bool, optional):
-             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate.
-
-     """
-     def __init__(
-         self,
-         params,
-         lr: float = 1e-3,
-         momentums: Iterable[float] = (0.5, 0.5, 0.5),
-         dampening: float | Iterable[float] = 0,
-         nesterov=True,
-         weight_decay: float = 0,
-         decoupled=True,
-     ):
-         momentums = list(momentums)
-         if isinstance(dampening, (int, float)): dampening = [dampening for _ in momentums]
-
-         cls = NesterovMomentum if nesterov else HeavyBall
-         modules: list = [cls(m, d) for m, d in zip(momentums, dampening)] + [LR(lr)]
-
-         if decoupled: modules.append(WeightDecay(weight_decay))
-         else: modules.insert(0, WeightDecay(weight_decay))
-
-         super().__init__(params, modules)
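Every preset in the removed file above follows the same pattern: a list of torchzero modules (a transform such as Adam, an LR step, and optionally WeightDecay) passed to Modular. A minimal usage sketch of that pattern against the 0.1.8 API shown in this diff; the import paths are inferred from the relative imports in the deleted file, and the step()/zero_grad() semantics are assumed to follow the usual torch.optim contract:

import torch
# Paths inferred from `from ..modular import Modular` and `from ...modules import ...`
# in the deleted 0.1.8 file above; they are not the 0.3.1 layout.
from torchzero.optim.modular import Modular
from torchzero.modules import Adam, LR, WeightDecay

model = torch.nn.Linear(10, 1)

# Same composition the deleted Adam preset used with decoupled=True:
# Adam transform -> LR scaling -> decoupled weight decay applied last.
opt = Modular(model.parameters(), [Adam(beta1=0.9, beta2=0.999, eps=1e-8), LR(1e-3), WeightDecay(1e-2)])

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()       # assumed to follow the standard torch.optim step()/zero_grad() semantics
opt.zero_grad()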