torchzero-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/modules/weight_averaging/swa.py
@@ -0,0 +1,171 @@
+ from ...core import OptimizerModule
+
+
+ def _reset_stats_hook(optimizer, state):
+     for module in optimizer.unrolled_modules:
+         module: OptimizerModule
+         module.reset_stats()
+
+ class PeriodicSWA(OptimizerModule):
+     """Periodic Stochastic Weight Averaging.
+
+     Please put this module at the end, after all other modules.
+
+     The algorithm is as follows:
+
+     1. Perform `pswa_start` normal steps before starting PSWA.
+
+     2. Perform multiple SWA iterations. On each iteration,
+     run the SWA algorithm for `num_cycles` cycles,
+     and set the weights to the weighted average before starting the next SWA iteration.
+
+     An SWA iteration is as follows:
+
+     1. Perform `cycle_start` initial steps (can be 0).
+
+     2. For `num_cycles` cycles, after every `cycle_length` steps, update the weight average with the current model weights.
+
+     3. After `num_cycles` cycles have passed, set the model parameters to the weight average.
+
+     Args:
+         pswa_start (int):
+             number of steps before starting PSWA; the authors run PSWA starting from the 40th epoch out of 150 epochs in total.
+         cycle_length (int):
+             number of steps between updates of the weight average. The authors update it once per epoch.
+         num_cycles (int):
+             number of weight average updates before setting model weights to the average and proceeding to the next cycle.
+             The authors use 20 (meaning 20 epochs, since each cycle is 1 epoch).
+         cycle_start (int, optional):
+             number of steps at the beginning of each SWA period before updating the weight average (default: 0).
+         reset_stats (bool, optional):
+             if True, when setting model parameters to SWA, resets other modules' stats such as momentum velocities (default: True).
+     """
+     def __init__(self, pswa_start: int, cycle_length: int, num_cycles: int, cycle_start: int = 0, reset_stats: bool = True):
+
+         super().__init__({})
+         self.pswa_start = pswa_start
+         self.cycle_start = cycle_start
+         self.cycle_length = cycle_length
+         self.num_cycles = num_cycles
+         self._reset_stats = reset_stats
+
+
+         self.cur = 0
+         self.period_cur = 0
+         self.swa_cur = 0
+         self.n_models = 0
+
+     def step(self, state):
+         swa = None
+         params = self.get_params()
+         ret = self._update_params_or_step_with_next(state, params)
+
+         # start first period after `pswa_start` steps
+         if self.cur >= self.pswa_start:
+
+             # start swa after `cycle_start` steps in the current period
+             if self.period_cur >= self.cycle_start:
+
+                 # swa updates on every `cycle_length`th step
+                 if self.swa_cur % self.cycle_length == 0:
+                     swa = self.get_state_key('swa') # initialized to zeros for simplicity
+                     swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
+                     self.n_models += 1
+
+                 self.swa_cur += 1
+
+             self.period_cur += 1
+
+         self.cur += 1
+
+         # passed num_cycles in period, set model parameters to SWA
+         if self.n_models == self.num_cycles:
+             self.period_cur = 0
+             self.swa_cur = 0
+             self.n_models = 0
+
+             assert swa is not None # it's created above self.n_models += 1
+
+             params.set_(swa)
+             # add a hook that resets momentum, which also deletes `swa` in this module
+             if self._reset_stats: state.add_post_step_hook(_reset_stats_hook)
+
+         return ret
+
+ class CyclicSWA(OptimizerModule):
+     """Periodic SWA with a cyclic learning rate. It samples the weights, increases the lr to `peak_lr`, samples the weights again,
+     decreases the lr back to `init_lr`, and samples the weights a last time. The model weights are then replaced with the average of the three sampled weights,
+     and the next cycle starts. I made this due to a horrible misreading of the original SWA paper, but it seems to work well.
+
+     Please put this module at the end, after all other modules.
+
+     Args:
+         cswa_start (int): number of steps before starting the first CSWA cycle.
+         cycle_length (int): length of each cycle in steps.
+         steps_between (int): number of steps between cycles.
+         init_lr (float, optional): initial and final learning rate in each cycle. Defaults to 0.
+         peak_lr (float, optional): peak learning rate of each cycle. Defaults to 1.
+         sample_all (bool, optional): if True, instead of sampling 3 weights, samples all weights in the cycle. Defaults to False.
+         reset_stats (bool, optional):
+             if True, when setting model parameters to SWA, resets other modules' stats such as momentum velocities (default: True).
+
+     """
+     def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1, sample_all = False, reset_stats: bool = True,):
+         defaults = dict(init_lr = init_lr, peak_lr = peak_lr)
+         super().__init__(defaults)
+         self.cswa_start = cswa_start
+         self.cycle_length = cycle_length
+         self.init_lr = init_lr
+         self.peak_lr = peak_lr
+         self.steps_between = steps_between
+         self.sample_all = sample_all
+         self._reset_stats = reset_stats
+
+         self.cur = 0
+         self.cycle_cur = 0
+         self.n_models = 0
+
+         self.cur_lr = self.init_lr
+
+     def step(self, state):
+         params = self.get_params()
+
+         # start first period after `cswa_start` steps
+         if self.cur >= self.cswa_start:
+
+             ascent = state.maybe_use_grad_(params)
+
+             # determine the lr
+             point = self.cycle_cur / self.cycle_length
+             init_lr, peak_lr = self.get_group_keys('init_lr', 'peak_lr')
+             if point < 0.5:
+                 p2 = point*2
+                 lr = init_lr * (1-p2) + peak_lr * p2
+             else:
+                 p2 = (1 - point)*2
+                 lr = init_lr * (1-p2) + peak_lr * p2
+
+             ascent *= lr
+             ret = self._update_params_or_step_with_next(state, params)
+
+             if self.sample_all or self.cycle_cur in (0, self.cycle_length, self.cycle_length // 2):
+                 swa = self.get_state_key('swa')
+                 swa.mul_(self.n_models).add_(params).div_(self.n_models + 1)
+                 self.n_models += 1
+
+             if self.cycle_cur == self.cycle_length:
+                 if not self.sample_all: assert self.n_models == 3, self.n_models
+                 self.n_models = 0
+                 self.cycle_cur = -1
+
+                 params.set_(swa)
+                 if self._reset_stats: state.add_post_step_hook(_reset_stats_hook)
+
+             self.cycle_cur += 1
+
+         else:
+             ret = self._update_params_or_step_with_next(state, params)
+
+         self.cur += 1
+
+         return ret
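
Both SWA modules above are meant to be composed inside a `Modular` optimizer and placed last in the module list, as their docstrings request. Below is a minimal usage sketch, not taken from the package: it assumes `SGD` and `LR` are re-exported at `torchzero.modules` (as the imports in `optim/experimental/experimental.py` further down suggest), that `PeriodicSWA` is exported from `torchzero.modules.weight_averaging`, and that `Modular` exposes the usual `zero_grad()`/`step()` optimizer interface.

import torch
from torchzero.optim import Modular
from torchzero.modules import SGD, LR                       # assumed re-export paths
from torchzero.modules.weight_averaging import PeriodicSWA  # assumed export path

model = torch.nn.Linear(10, 1)
opt = Modular(
    model.parameters(),
    [
        SGD(momentum=0.9, dampening=0, weight_decay=0, nesterov=False),
        LR(1e-2),
        # per the docstring, the SWA module goes last, after all other modules
        PeriodicSWA(pswa_start=4000, cycle_length=100, num_cycles=20),
    ],
)

for _ in range(5000):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    opt.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    opt.step()
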
torchzero/optim/__init__.py
@@ -0,0 +1,10 @@
+ r"""
+ Ready to use optimizers.
+ """
+ from .modular import Modular
+ from .quasi_newton import *
+ from .zeroth_order import *
+ from .second_order import *
+ from .first_order import *
+ # from .wrappers.scipy import ScipyMinimize
+ from . import experimental
torchzero/optim/experimental/__init__.py
@@ -0,0 +1,20 @@
+ """Optimizers that I haven't tested and various (mostly stupid) ideas go here.
+ If something works well I will move it outside of the experimental folder.
+ Otherwise, all optimizers in this category should be considered unlikely to be good for most tasks."""
+ from .experimental import (
+     HVPDiagNewton,
+     ExaggeratedNesterov,
+     ExtraCautiousAdam,
+     GradMin,
+     InwardSGD,
+     MinibatchRprop,
+     MomentumDenominator,
+     MomentumNumerator,
+     MultistepSGD,
+     RandomCoordinateMomentum,
+     ReciprocalSGD,
+     NoiseSign,
+ )
+
+
+ from .ray_search import NewtonFDMRaySearch, LBFGSRaySearch
torchzero/optim/experimental/experimental.py
@@ -0,0 +1,343 @@
+ from typing import Literal
+
+ from ...modules import (
+     LR,
+     SGD,
+     Abs,
+     Adam,
+     Add,
+     AddMagnitude,
+     Cautious,
+     Div,
+     Divide,
+     Grad,
+     HeavyBall,
+     Interpolate,
+     Lerp,
+     Multistep,
+     NanToNum,
+     NesterovMomentum,
+     Normalize,
+     Random,
+     RDiv,
+     Reciprocal,
+     UseGradSign,
+     WeightDecay,
+ )
+ from ...modules import RandomCoordinateMomentum as _RandomCoordinateMomentum
+ from ...modules.experimental import GradMin as _GradMin
+ from ...modules.experimental import (
+     HVPDiagNewton as _HVPDiagNewton,
+ )
+ from ...modules.experimental import MinibatchRprop as _MinibatchRprop
+ from ...modules.experimental import ReduceOutwardLR
+ from ...random import Distributions
+ from ..modular import Modular
+
+
+ class HVPDiagNewton(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - this should approximate Newton's method with 2 backward passes, but only if the hessian is purely diagonal."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-1,
+         eps: float = 1e-2,
+     ):
+         modules = [_HVPDiagNewton(eps = eps), LR(lr)]
+         super().__init__(params, modules)
+
+
+ class ReciprocalSGD(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - this basically uses normalized *1 / (gradient + eps)*."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         eps: float = 1e-2,
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             AddMagnitude(eps, add_to_zero=False),
+             Reciprocal(),
+             NanToNum(0,0,0),
+             Normalize(1),
+             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+
+         super().__init__(params, modules)
+
+ class NoiseSign(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - uses a random vector with the gradient sign, and works quite well despite being completely random."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         distribution: Distributions = 'normal',
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             Random(1, distribution),
+             UseGradSign(),
+             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(2, WeightDecay(weight_decay))
+
+         super().__init__(params, modules)
+
+ class MomentumNumerator(Modular):
+     """for experiments, unlikely to work well on most problems. (somewhat promising)
+
+     explanation - momentum divided by the gradient."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         momentum: float = 0.9,
+         nesterov: bool = True,
+         eps: float = 1e-2,
+         weight_decay: float = 0,
+         decoupled=True, ):
+
+         modules: list = [
+             Divide(
+                 numerator = SGD(momentum = momentum, nesterov=nesterov),
+                 denominator=[Abs(), Add(eps)]
+             ),
+             Normalize(),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+ class MomentumDenominator(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - gradient divided by normalized momentum."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         momentum: float = 0.9,
+         nesterov: bool = True,
+         eps: float = 1e-2,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             Div([SGD(momentum=momentum, nesterov=nesterov), Abs(), Add(eps), Normalize(1)]),
+             Normalize(),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+
+ class ExaggeratedNesterov(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - exaggerates the difference between heavy ball and Nesterov momentum."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         momentum: float = 0.9,
+         dampening: float = 0,
+         strength: float = 5,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+
+         modules: list = [
+             Interpolate(HeavyBall(momentum, dampening), NesterovMomentum(momentum, dampening), strength),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+ class ExtraCautiousAdam(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - caution with true backtracking."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1,
+         beta1: float = 0.9,
+         beta2: float = 0.999,
+         eps: float = 1e-8,
+         amsgrad=False,
+         normalize = False,
+         c_eps = 1e-6,
+         mode: Literal['zero', 'grad', 'backtrack'] = 'zero',
+         strength = 5,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             Adam(beta1, beta2, eps, amsgrad=amsgrad),
+             Lerp(Cautious(normalize, c_eps, mode), strength),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+ class InwardSGD(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - reduces lrs for updates that move weights away from 0."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         mul = 0.5,
+         use_grad=False,
+         invert=False,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
+             LR(lr),
+             ReduceOutwardLR(mul, use_grad, invert),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+ class MultistepSGD(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - performs multiple steps per batch. Momentum applies to the total update over the multiple steps."""
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         num_steps=2,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         # lr, lr_module = _get_baked_in_and_module_lr(lr, kwargs) # multistep must use lr
+
+         modules: list = [
+             Multistep(LR(lr), num_steps=num_steps),
+             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+
+ class MinibatchRprop(Modular):
+     """
+     for experiments, unlikely to work well on most problems.
+
+     explanation: does 2 steps per batch, applies the rprop rule on the second step.
+     """
+     def __init__(
+         self,
+         params,
+         lr: float = 1,
+         nplus: float = 1.2,
+         nminus: float = 0.5,
+         lb: float | None = 1e-6,
+         ub: float | None = 50,
+         backtrack=True,
+         next_mode = 'continue',
+         increase_mul = 0.5,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             _MinibatchRprop(nplus=nplus,nminus=nminus,lb=lb,ub=ub,backtrack=backtrack,next_mode=next_mode,increase_mul=increase_mul),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+
+ class RandomCoordinateMomentum(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     Only uses `p` random coordinates of the new update. The other coordinates remain from the previous update.
+     This works but I don't know if it is any good.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float): learning rate (default: 1e-3).
+         p (float, optional): probability to update the velocity with a new weight value. Defaults to 0.1.
+         nesterov (bool, optional): if False, the update uses delayed momentum. Defaults to True.
+
+     """
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         p: float = 0.1,
+         nesterov: bool = True,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [_RandomCoordinateMomentum(p, nesterov), LR(lr)]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+ class GradMin(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - this uses the gradient w.r.t. the sum of gradients + loss."""
+
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-2,
+         loss_term: float = 1,
+         square: bool = False,
+         maximize_grad: bool = False,
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         weight_decay: float = 0,
+         decoupled=True,
+     ):
+         modules: list = [
+             _GradMin(loss_term, square, maximize_grad),
+             SGD(momentum = momentum, dampening = dampening, weight_decay = 0, nesterov = nesterov),
+             LR(lr),
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         else: modules.insert(0, WeightDecay(weight_decay))
+         super().__init__(params, modules)
+
+
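
Every preset in this file follows the same recipe: subclass `Modular`, build a list of modules ending in `LR`, then either append a `WeightDecay` module after the learning rate (decoupled) or insert one at the front (plain L2-style decay on the gradient). The sketch below shows that pattern with a hypothetical preset of my own; the class name and the particular module composition are illustrative, not part of torchzero.

from torchzero.optim import Modular
from torchzero.modules import SGD, LR, Normalize, WeightDecay

class NormalizedSGDW(Modular):
    """hypothetical preset: normalize the update, apply momentum and lr, then decoupled weight decay."""
    def __init__(self, params, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, decoupled=True):
        modules: list = [
            Normalize(1),
            SGD(momentum = momentum, dampening = 0, weight_decay = 0, nesterov = False),
            LR(lr),
        ]
        # decoupled: decay is applied after the LR module, i.e. on the final update;
        # otherwise it is added to the gradient before all other modules
        if decoupled: modules.append(WeightDecay(weight_decay))
        else: modules.insert(0, WeightDecay(weight_decay))
        super().__init__(params, modules)
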
torchzero/optim/experimental/ray_search.py
@@ -0,0 +1,83 @@
+ from typing import Literal, Any
+
+ import torch
+
+ from ...core import OptimizerModule
+ from ...modules import (SGD, LineSearches, NewtonFDM,
+                         get_line_search, LR, WrapClosure)
+ from ...modules.experimental.subspace import Subspace, ProjNormalize, ProjAscentRay
+ from ..modular import Modular
+
+
+ class NewtonFDMRaySearch(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - like a fancy line search, but instead of a line it searches in a cone using FDM Newton."""
+     def __init__(
+         self,
+         params,
+         lr = 1e-2,
+         momentum: float = 0,
+         weight_decay: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         n_rays = 3,
+         eps = 1e-2,
+         ray_width: float = 1e-1,
+         line_search: LineSearches | None = 'brent'
+     ):
+         modules: list[Any] = [
+             SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
+             LR(lr),
+             Subspace(NewtonFDM(eps = eps), ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
+         ]
+         if lr != 1:
+             modules.append(LR(lr))
+
+         if line_search is not None:
+             modules.append(get_line_search(line_search))
+
+         super().__init__(params, modules)
+
+
+ class LBFGSRaySearch(Modular):
+     """for experiments, unlikely to work well on most problems.
+
+     explanation - like a fancy line search, but instead of a line it searches in a cone using LBFGS."""
+     def __init__(
+         self,
+         params,
+         lr = 1,
+         momentum: float = 0,
+         weight_decay: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         n_rays = 24,
+         ray_width: float = 1e-1,
+         max_iter: int = 20,
+         max_eval: int | None = None,
+         tolerance_grad: float = 1e-7,
+         tolerance_change: float = 1e-9,
+         history_size: int = 100,
+         line_search_fn: str | Literal['strong_wolfe'] | None = None,
+     ):
+         lbfgs = WrapClosure(
+             torch.optim.LBFGS,
+             lr=lr,
+             max_iter=max_iter,
+             max_eval=max_eval,
+             tolerance_grad=tolerance_grad,
+             tolerance_change=tolerance_change,
+             history_size=history_size,
+             line_search_fn=line_search_fn,
+         )
+         modules: list[OptimizerModule] = [
+             SGD(momentum=momentum, weight_decay=weight_decay, dampening=dampening, nesterov=nesterov),
+             Subspace(lbfgs, ProjNormalize(ProjAscentRay(ray_width, n = n_rays))),
+
+         ]
+
+         super().__init__(params, modules)
+
+
+
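
Both ray-search presets run an inner optimizer (FDM Newton or a wrapped `torch.optim.LBFGS`) inside a `Subspace` projection, so each step has to evaluate the objective several times. A hedged usage sketch follows, under the assumption that `Modular.step` accepts a closure the same way `torch.optim.LBFGS` does; the exact calling convention is not shown in this diff.

import torch
from torchzero.optim.experimental import NewtonFDMRaySearch

model = torch.nn.Linear(4, 1)
data, target = torch.randn(32, 4), torch.randn(32, 1)
opt = NewtonFDMRaySearch(model.parameters(), lr=1e-2, n_rays=3)

def closure():
    # presumably re-evaluated by the optimizer whenever it needs a fresh loss/gradient
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    return loss

for _ in range(20):
    opt.step(closure)
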
torchzero/optim/first_order/__init__.py
@@ -0,0 +1,18 @@
+ from .cautious import CautiousAdamW, CautiousLion, CautiousSGD
+ from .optimizers import (
+     GD,
+     SGD,
+     Adagrad,
+     Adam,
+     AdamW,
+     Grams,
+     LaplacianSmoothingSGD,
+     Lion,
+     NestedNesterov,
+     NoisySGD,
+     NormSGD,
+     RMSProp,
+     Rprop,
+     SignSGD,
+ )
+ from .forward_gradient import ForwardGradient