torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -3,104 +3,274 @@ from typing import Any
  from functools import partial
  import torch

- from ...utils import TensorList, Distributions, NumberList, generic_eq
+ from ...utils import TensorList, Distributions, NumberList
  from .grad_approximator import GradApproximator, GradTarget, _FD_Formula

-
- def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+ def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
      """p_fn is a function that returns the perturbation.
      It may return pre-generated one or generate one deterministically from a seed as in MeZO.
      Returned perturbation must be multiplied by `h`."""
-     if v_0 is None: v_0 = closure(False)
+     if f_0 is None: f_0 = closure(False)
      params += p_fn()
-     v_plus = closure(False)
+     f_1 = closure(False)
      params -= p_fn()
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+     return f_0, f_0, (f_1 - f_0) / h # (loss, loss_approx, grad)

- def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
-     if v_0 is None: v_0 = closure(False)
+ def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     if f_0 is None: f_0 = closure(False)
      params -= p_fn()
-     v_minus = closure(False)
+     f_m1 = closure(False)
      params += p_fn()
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_0, (v_0 - v_minus) / h
+     return f_0, f_0, (f_0 - f_m1) / h

- def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: Any):
+ def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: Any):
      params += p_fn()
-     v_plus = closure(False)
+     f_1 = closure(False)

      params -= p_fn() * 2
-     v_minus = closure(False)
+     f_m1 = closure(False)

      params += p_fn()
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+     return f_0, f_1, (f_1 - f_m1) / (2 * h)

- def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
-     if v_0 is None: v_0 = closure(False)
+ def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     if f_0 is None: f_0 = closure(False)
      params += p_fn()
-     v_plus1 = closure(False)
+     f_1 = closure(False)

      params += p_fn()
-     v_plus2 = closure(False)
+     f_2 = closure(False)

      params -= p_fn() * 2
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+     return f_0, f_0, (-3*f_0 + 4*f_1 - f_2) / (2 * h)

- def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
-     if v_0 is None: v_0 = closure(False)
+ def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     if f_0 is None: f_0 = closure(False)

      params -= p_fn()
-     v_minus1 = closure(False)
+     f_m1 = closure(False)

      params -= p_fn()
-     v_minus2 = closure(False)
+     f_m2 = closure(False)

      params += p_fn() * 2
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+     return f_0, f_0, (f_m2 - 4*f_m1 + 3*f_0) / (2 * h)

- def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+ def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
      params += p_fn()
-     v_plus1 = closure(False)
+     f_1 = closure(False)

      params += p_fn()
-     v_plus2 = closure(False)
+     f_2 = closure(False)

      params -= p_fn() * 3
-     v_minus1 = closure(False)
+     f_m1 = closure(False)

      params -= p_fn()
-     v_minus2 = closure(False)
+     f_m2 = closure(False)
+
+     params += p_fn() * 2
+     h = h**2 # because perturbation already multiplied by h
+     return f_0, f_1, (f_m2 - 8*f_m1 + 8*f_1 - f_2) / (12 * h)
+
+ # some good ones
+ # Pachalyl S. et al. Generalized simultaneous perturbation-based gradient search with reduced estimator bias //IEEE Transactions on Automatic Control. – 2025.
+ # Three measurements GSPSA is _rforward3
+ # Four measurements GSPSA
+ def _rforward4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     if f_0 is None: f_0 = closure(False)
+     params += p_fn()
+     f_1 = closure(False)
+
+     params += p_fn()
+     f_2 = closure(False)
+
+     params += p_fn()
+     f_3 = closure(False)
+
+     params -= p_fn() * 3
+     h = h**2 # because perturbation already multiplied by h
+     return f_0, f_0, (2*f_3 - 9*f_2 + 18*f_1 - 11*f_0) / (6 * h)
+
+ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     if f_0 is None: f_0 = closure(False)
+     params += p_fn()
+     f_1 = closure(False)
+
+     params += p_fn()
+     f_2 = closure(False)
+
+     params += p_fn()
+     f_3 = closure(False)
+
+     params += p_fn()
+     f_4 = closure(False)
+
+     params -= p_fn() * 4
+     h = h**2 # because perturbation already multiplied by h
+     return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)
+
+ # another central4
+ def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+     params += p_fn()
+     f_1 = closure(False)

      params += p_fn() * 2
+     f_3 = closure(False)
+
+     params -= p_fn() * 4
+     f_m1 = closure(False)
+
+     params -= p_fn() * 2
+     f_m3 = closure(False)
+
+     params += p_fn() * 3
      h = h**2 # because perturbation already multiplied by h
-     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+     return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+

  _RFD_FUNCS = {
+     "forward": _rforward2,
      "forward2": _rforward2,
+     "backward": _rbackward2,
      "backward2": _rbackward2,
+     "central": _rcentral2,
      "central2": _rcentral2,
+     "central3": _rcentral2,
      "forward3": _rforward3,
      "backward3": _rbackward3,
      "central4": _rcentral4,
+     "forward4": _rforward4,
+     "forward5": _rforward5,
+     "bspsa4": _bgspsa4,
  }


  class RandomizedFDM(GradApproximator):
+     """Gradient approximation via a randomized finite-difference method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+     Args:
+         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+         n_samples (int, optional): number of random gradient samples. Defaults to 1.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+         distribution (Distributions, optional): distribution. Defaults to "rademacher".
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+         beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
+         pre_generate (bool, optional):
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     Examples:
+         #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+         SPSA is randomized finite differnce with rademacher distribution and central formula.
+
+         .. code-block:: python
+
+             spsa = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+                 tz.m.LR(1e-2)
+             )
+
+         #### Random-direction stochastic approximation (RDSA) method
+
+         RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+         .. code-block:: python
+
+             rdsa = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+                 tz.m.LR(1e-2)
+             )
+
+         #### RandomizedFDM with momentum
+
+         Momentum might help by reducing the variance of the estimated gradients.
+
+         .. code-block:: python
+
+             momentum_spsa = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(),
+                 tz.m.HeavyBall(0.9),
+                 tz.m.LR(1e-3)
+             )
+
+         #### Gaussian smoothing method
+
+         GS uses many gaussian samples with possibly a larger finite difference step size.
+
+         .. code-block:: python
+
+             gs = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+                 tz.m.NewtonCG(hvp_method="forward"),
+                 tz.m.Backtracking()
+             )
+
+         #### SPSA-NewtonCG
+
+         NewtonCG with hessian-vector product estimated via gradient difference
+         calls closure multiple times per step. If each closure call estimates gradients
+         with different perturbations, NewtonCG is unable to produce useful directions.
+
+         By setting pre_generate to True, perturbations are generated once before each step,
+         and each closure call estimates gradients using the same pre-generated perturbations.
+         This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(n_samples=10),
+                 tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+                 tz.m.Backtracking()
+             )
+
+         #### SPSA-BFGS
+
+         L-BFGS uses a memory of past parameter and gradient differences. If past gradients
+         were estimated with different perturbations, L-BFGS directions will be useless.
+
+         To alleviate this momentum can be added to random perturbations to make sure they only
+         change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+         The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+         Additionally we will reset BFGS memory every 100 steps to remove influence from old gradient estimates.
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 model.parameters(),
+                 tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99),
+                 tz.m.BFGS(reset_interval=100),
+                 tz.m.Backtracking()
+             )
+     """
      PRE_MULTIPLY_BY_H = True
      def __init__(
          self,
          h: float = 1e-3,
          n_samples: int = 1,
-         formula: _FD_Formula = "central2",
+         formula: _FD_Formula = "central",
          distribution: Distributions = "rademacher",
          beta: float = 0,
          pre_generate = True,
-         target: GradTarget = "closure",
          seed: int | None | torch.Generator = None,
+         target: GradTarget = "closure",
      ):
          defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
          super().__init__(defaults, target=target)
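A note on the recurring `h = h**2` comment in the helper functions above: every perturbation returned by `p_fn()` is already scaled by `h`, so dividing the finite difference by `h**2` and then multiplying by the perturbation itself recovers the usual SPSA estimate `v * (f(x + h*v) - f(x - h*v)) / (2*h)`. Below is a minimal stand-alone sketch of the central-difference version of that estimate in plain PyTorch; it does not use the torchzero module API, and the function name `spsa_central_grad` is illustrative only.

import torch

def spsa_central_grad(f, x, h=1e-3):
    # Rademacher direction v in {-1, +1}^n; the perturbation p = h * v is "already multiplied by h"
    v = torch.randint_like(x, 0, 2) * 2 - 1
    p = h * v
    f_plus = f(x + p)
    f_minus = f(x - p)
    d = (f_plus - f_minus) / (2 * h**2)   # divide by h**2, mirroring the h = h**2 trick above
    return p * d                          # equals v * (f_plus - f_minus) / (2 * h)

# quick check on a quadratic: in expectation the estimate equals the true gradient 2*x
x = torch.tensor([1.0, -2.0, 3.0])
g = spsa_central_grad(lambda y: (y ** 2).sum(), x)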
@@ -118,16 +288,16 @@ class RandomizedFDM(GradApproximator):
          else: self.global_state['generator'] = None
          return self.global_state['generator']

-     def pre_step(self, vars):
-         h, beta = self.get_settings('h', 'beta', params=vars.params)
-         settings = self.settings[vars.params[0]]
+     def pre_step(self, var):
+         h, beta = self.get_settings(var.params, 'h', 'beta')
+         settings = self.settings[var.params[0]]
          n_samples = settings['n_samples']
          distribution = settings['distribution']
          pre_generate = settings['pre_generate']

          if pre_generate:
-             params = TensorList(vars.params)
-             generator = self._get_generator(settings['seed'], vars.params)
+             params = TensorList(var.params)
+             generator = self._get_generator(settings['seed'], var.params)
              perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]

              if self.PRE_MULTIPLY_BY_H:
@@ -152,11 +322,12 @@ class RandomizedFDM(GradApproximator):
              torch._foreach_lerp_(cur_flat, new_flat, betas)

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss):
          params = TensorList(params)
+         orig_params = params.clone() # store to avoid small changes due to float imprecision
          loss_approx = None

-         h = self.get_settings('h', params=vars.params, cls=NumberList)
+         h = NumberList(self.settings[p]['h'] for p in params)
          settings = self.settings[params[0]]
          n_samples = settings['n_samples']
          fd_fn = _RFD_FUNCS[settings['formula']]
@@ -171,17 +342,64 @@ class RandomizedFDM(GradApproximator):
              if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
              else: prt = TensorList(prt)

-             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, v_0=loss)
+             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
              if grad is None: grad = prt * d
              else: grad += prt * d

+         params.set_(orig_params)
          assert grad is not None
          if n_samples > 1: grad.div_(n_samples)
          return grad, loss, loss_approx

- SPSA = RandomizedFDM
+ class SPSA(RandomizedFDM):
+     """
+     Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+
+     Args:
+         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+         n_samples (int, optional): number of random gradient samples. Defaults to 1.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+         distribution (Distributions, optional): distribution. Defaults to "rademacher".
+         beta (float, optional):
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+         pre_generate (bool, optional):
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     References:
+         Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+     """

  class RDSA(RandomizedFDM):
+     """
+     Gradient approximation via Random-direction stochastic approximation (RDSA) method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+     Args:
+         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+         n_samples (int, optional): number of random gradient samples. Defaults to 1.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+         distribution (Distributions, optional): distribution. Defaults to "gaussian".
+         beta (float, optional):
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+         pre_generate (bool, optional):
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     References:
+         Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+
+     """
      def __init__(
          self,
          h: float = 1e-3,
@@ -196,11 +414,34 @@ class RDSA(RandomizedFDM):
          super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)

  class GaussianSmoothing(RandomizedFDM):
+     """
+     Gradient approximation via Gaussian smoothing method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+     Args:
+         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-2.
+         n_samples (int, optional): number of random gradient samples. Defaults to 100.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
+         distribution (Distributions, optional): distribution. Defaults to "gaussian".
+         beta (float, optional):
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+         pre_generate (bool, optional):
+             whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+         seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+
+     References:
+         Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
+     """
      def __init__(
          self,
          h: float = 1e-2,
          n_samples: int = 100,
-         formula: _FD_Formula = "central2",
+         formula: _FD_Formula = "forward2",
          distribution: Distributions = "gaussian",
          beta: float = 0,
          pre_generate = True,
@@ -210,8 +451,27 @@ class GaussianSmoothing(RandomizedFDM):
          super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)

  class MeZO(GradApproximator):
+     """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+     Args:
+         h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+         n_samples (int, optional): number of random gradient samples. Defaults to 1.
+         formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+         distribution (Distributions, optional): distribution. Defaults to "rademacher".
+             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+         target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     References:
+         Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
+     """
+
      def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
                   distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+
          defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
          super().__init__(defaults, target=target)

@@ -220,29 +480,29 @@ class MeZO(GradApproximator):
              distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
          ).mul_(h)

-     def pre_step(self, vars):
-         h = self.get_settings('h', params=vars.params)
-         settings = self.settings[vars.params[0]]
+     def pre_step(self, var):
+         h = NumberList(self.settings[p]['h'] for p in var.params)
+         settings = self.settings[var.params[0]]
          n_samples = settings['n_samples']
          distribution = settings['distribution']

-         step = vars.current_step
+         step = var.current_step

          # create functions that generate a deterministic perturbation from seed based on current step
          prt_fns = []
          for i in range(n_samples):

-             prt_fn = partial(self._seeded_perturbation, params=vars.params, distribution=distribution, seed=1_000_000*step + i, h=h)
+             prt_fn = partial(self._seeded_perturbation, params=var.params, distribution=distribution, seed=1_000_000*step + i, h=h)
              prt_fns.append(prt_fn)

          self.global_state['prt_fns'] = prt_fns

      @torch.no_grad
-     def approximate(self, closure, params, loss, vars):
+     def approximate(self, closure, params, loss):
          params = TensorList(params)
          loss_approx = None

-         h = self.get_settings('h', params=vars.params, cls=NumberList)
+         h = NumberList(self.settings[p]['h'] for p in params)
          settings = self.settings[params[0]]
          n_samples = settings['n_samples']
          fd_fn = _RFD_FUNCS[settings['formula']]
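The MeZO hunks above hinge on one trick: a perturbation is never stored, it is re-generated on demand from a deterministic seed derived from the current step (`seed=1_000_000*step + i` in `pre_step`), so every closure evaluation within a step sees the same direction at essentially zero memory cost. A small illustration of that idea, independent of the torchzero classes; the helper name `seeded_perturbation` is made up for this sketch.

import torch

def seeded_perturbation(shape, seed, h, device="cpu"):
    # re-create the same Rademacher direction from the seed instead of keeping it in memory
    gen = torch.Generator(device=device).manual_seed(seed)
    v = torch.randint(0, 2, shape, generator=gen, device=device, dtype=torch.float32) * 2 - 1
    return v * h  # returned perturbation is already multiplied by h, as the FD helpers expect

step, i, h = 7, 0, 1e-3
seed = 1_000_000 * step + i                  # same seeding scheme as MeZO.pre_step
p_a = seeded_perturbation((3,), seed, h)
p_b = seeded_perturbation((3,), seed, h)
assert torch.equal(p_a, p_b)                 # identical on every call, nothing had to be stored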
@@ -250,7 +510,7 @@ class MeZO(GradApproximator):

          grad = None
          for i in range(n_samples):
-             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, v_0=loss)
+             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
              if grad is None: grad = prt_fns[i]().mul_(d)
              else: grad += prt_fns[i]().mul_(d)

@@ -0,0 +1 @@
+ from .higher_order_newton import HigherOrderNewton
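For context on how the new `SPSA`/`RDSA`/`GaussianSmoothing` subclasses are meant to be used, the pattern below follows the docstring examples in the diff. The training loop itself is illustrative rather than taken from the package, and it assumes the usual torchzero convention that the user closure accepts a `backward` flag (the approximator calls `closure(False)` internally).

import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.SPSA(),   # 0.3.11 turns the old `SPSA = RandomizedFDM` alias into a documented subclass
    tz.m.LR(1e-2),
)

def closure(backward=True):
    # gradients are estimated by the SPSA module, so no loss.backward() is needed here
    return torch.nn.functional.mse_loss(model(X), y)

for _ in range(100):
    opt.step(closure)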