torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/grad_approximation/rfdm.py

@@ -3,94 +3,160 @@ from typing import Any
  from functools import partial
  import torch

- from ...utils import TensorList, Distributions, NumberList, generic_eq
+ from ...utils import TensorList, Distributions, NumberList
  from .grad_approximator import GradApproximator, GradTarget, _FD_Formula

-
- def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+ def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
  """p_fn is a function that returns the perturbation.
  It may return pre-generated one or generate one deterministically from a seed as in MeZO.
  Returned perturbation must be multiplied by `h`."""
- if v_0 is None: v_0 = closure(False)
+ if f_0 is None: f_0 = closure(False)
  params += p_fn()
- v_plus = closure(False)
+ f_1 = closure(False)
  params -= p_fn()
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+ return f_0, f_0, (f_1 - f_0) / h # (loss, loss_approx, grad)

- def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
- if v_0 is None: v_0 = closure(False)
+ def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ if f_0 is None: f_0 = closure(False)
  params -= p_fn()
- v_minus = closure(False)
+ f_m1 = closure(False)
  params += p_fn()
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_0, (v_0 - v_minus) / h
+ return f_0, f_0, (f_0 - f_m1) / h

- def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: Any):
+ def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: Any):
  params += p_fn()
- v_plus = closure(False)
+ f_1 = closure(False)

  params -= p_fn() * 2
- v_minus = closure(False)
+ f_m1 = closure(False)

  params += p_fn()
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+ return f_0, f_1, (f_1 - f_m1) / (2 * h)

- def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
- if v_0 is None: v_0 = closure(False)
+ def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ if f_0 is None: f_0 = closure(False)
  params += p_fn()
- v_plus1 = closure(False)
+ f_1 = closure(False)

  params += p_fn()
- v_plus2 = closure(False)
+ f_2 = closure(False)

  params -= p_fn() * 2
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+ return f_0, f_0, (-3*f_0 + 4*f_1 - f_2) / (2 * h)

- def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
- if v_0 is None: v_0 = closure(False)
+ def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ if f_0 is None: f_0 = closure(False)

  params -= p_fn()
- v_minus1 = closure(False)
+ f_m1 = closure(False)

  params -= p_fn()
- v_minus2 = closure(False)
+ f_m2 = closure(False)

  params += p_fn() * 2
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+ return f_0, f_0, (f_m2 - 4*f_m1 + 3*f_0) / (2 * h)

- def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+ def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
  params += p_fn()
- v_plus1 = closure(False)
+ f_1 = closure(False)

  params += p_fn()
- v_plus2 = closure(False)
+ f_2 = closure(False)

  params -= p_fn() * 3
- v_minus1 = closure(False)
+ f_m1 = closure(False)

  params -= p_fn()
- v_minus2 = closure(False)
+ f_m2 = closure(False)

  params += p_fn() * 2
  h = h**2 # because perturbation already multiplied by h
- return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+ return f_0, f_1, (f_m2 - 8*f_m1 + 8*f_1 - f_2) / (12 * h)
+
+ # some good ones
+ # Pachalyl S. et al. Generalized simultaneous perturbation-based gradient search with reduced estimator bias //IEEE Transactions on Automatic Control. – 2025.
+ # Three measurements GSPSA is _rforward3
+ # Four measurements GSPSA
+ def _rforward4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ if f_0 is None: f_0 = closure(False)
+ params += p_fn()
+ f_1 = closure(False)
+
+ params += p_fn()
+ f_2 = closure(False)
+
+ params += p_fn()
+ f_3 = closure(False)
+
+ params -= p_fn() * 3
+ h = h**2 # because perturbation already multiplied by h
+ return f_0, f_0, (2*f_3 - 9*f_2 + 18*f_1 - 11*f_0) / (6 * h)
+
+ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ if f_0 is None: f_0 = closure(False)
+ params += p_fn()
+ f_1 = closure(False)
+
+ params += p_fn()
+ f_2 = closure(False)
+
+ params += p_fn()
+ f_3 = closure(False)
+
+ params += p_fn()
+ f_4 = closure(False)
+
+ params -= p_fn() * 4
+ h = h**2 # because perturbation already multiplied by h
+ return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)
+
+ # # another central4
+ # def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ # params += p_fn()
+ # f_1 = closure(False)
+
+ # params += p_fn() * 2
+ # f_3 = closure(False)
+
+ # params -= p_fn() * 4
+ # f_m1 = closure(False)
+
+ # params -= p_fn() * 2
+ # f_m3 = closure(False)
+
+ # params += p_fn() * 3
+ # h = h**2 # because perturbation already multiplied by h
+ # return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+

- _RFD_FUNCS = {
+ _RFD_FUNCS: dict[_FD_Formula, Callable] = {
+ "forward": _rforward2,
  "forward2": _rforward2,
+ "backward": _rbackward2,
  "backward2": _rbackward2,
+ "central": _rcentral2,
  "central2": _rcentral2,
+ "central3": _rcentral2,
  "forward3": _rforward3,
  "backward3": _rbackward3,
  "central4": _rcentral4,
+ "forward4": _rforward4,
+ "forward5": _rforward5,
+ # "bspsa4": _bgspsa4,
  }


  class RandomizedFDM(GradApproximator):
- """_summary_
+ """Gradient approximation via a randomized finite-difference method.
+
+ Note:
+ This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+ and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

  Args:
  h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
@@ -98,17 +164,109 @@ class RandomizedFDM(GradApproximator):
  formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
  distribution (Distributions, optional): distribution. Defaults to "rademacher".
  If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+ beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
  pre_generate (bool, optional):
  whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
  seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
  target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+ Examples:
+ #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+ SPSA is randomized finite differnce with rademacher distribution and central formula.
+ ```py
+ spsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+ tz.m.LR(1e-2)
+ )
+ ```
+
+ #### Random-direction stochastic approximation (RDSA) method
+
+ RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+ ```
+ rdsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+ tz.m.LR(1e-2)
+ )
+ ```
+
+ #### RandomizedFDM with momentum
+
+ Momentum might help by reducing the variance of the estimated gradients.
+
+ ```
+ momentum_spsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(),
+ tz.m.HeavyBall(0.9),
+ tz.m.LR(1e-3)
+ )
+ ```
+
+ #### Gaussian smoothing method
+
+ GS uses many gaussian samples with possibly a larger finite difference step size.
+
+ ```
+ gs = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+ tz.m.NewtonCG(hvp_method="forward"),
+ tz.m.Backtracking()
+ )
+ ```
+
+ #### SPSA-NewtonCG
+
+ NewtonCG with hessian-vector product estimated via gradient difference
+ calls closure multiple times per step. If each closure call estimates gradients
+ with different perturbations, NewtonCG is unable to produce useful directions.
+
+ By setting pre_generate to True, perturbations are generated once before each step,
+ and each closure call estimates gradients using the same pre-generated perturbations.
+ This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+ ```
+ opt = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(n_samples=10),
+ tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+ tz.m.Backtracking()
+ )
+ ```
+
+ #### SPSA-LBFGS
+
+ LBFGS uses a memory of past parameter and gradient differences. If past gradients
+ were estimated with different perturbations, LBFGS directions will be useless.
+
+ To alleviate this momentum can be added to random perturbations to make sure they only
+ change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+ The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+ Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
+
+ ```
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.ResetEvery(
+ [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
+ steps = 100,
+ ),
+ tz.m.Backtracking()
+ )
+ ```
  """
  PRE_MULTIPLY_BY_H = True
  def __init__(
  self,
  h: float = 1e-3,
  n_samples: int = 1,
- formula: _FD_Formula = "central2",
+ formula: _FD_Formula = "central",
  distribution: Distributions = "rademacher",
  beta: float = 0,
  pre_generate = True,
@@ -123,6 +281,7 @@
  generator = self.global_state.get('generator', None) # avoid resetting generator
  self.global_state.clear()
  if generator is not None: self.global_state['generator'] = generator
+ for c in self.children.values(): c.reset()

  def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
  if 'generator' not in self.global_state:
@@ -133,15 +292,15 @@

  def pre_step(self, var):
  h, beta = self.get_settings(var.params, 'h', 'beta')
- settings = self.settings[var.params[0]]
- n_samples = settings['n_samples']
- distribution = settings['distribution']
- pre_generate = settings['pre_generate']
+
+ n_samples = self.defaults['n_samples']
+ distribution = self.defaults['distribution']
+ pre_generate = self.defaults['pre_generate']

  if pre_generate:
  params = TensorList(var.params)
- generator = self._get_generator(settings['seed'], var.params)
- perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+ generator = self._get_generator(self.defaults['seed'], var.params)
+ perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

  if self.PRE_MULTIPLY_BY_H:
  torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -165,8 +324,9 @@
  torch._foreach_lerp_(cur_flat, new_flat, betas)

  @torch.no_grad
- def approximate(self, closure, params, loss, var):
+ def approximate(self, closure, params, loss):
  params = TensorList(params)
+ orig_params = params.clone() # store to avoid small changes due to float imprecision
  loss_approx = None

  h = NumberList(self.settings[p]['h'] for p in params)
@@ -181,20 +341,84 @@
  grad = None
  for i in range(n_samples):
  prt = perturbations[i]
- if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
+
+ if prt[0] is None:
+ prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
  else: prt = TensorList(prt)

- loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, v_0=loss)
+ loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+ # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+ # support for per-sample values which gives better estimate
+ if d[0].numel() > 1: d = d.map(torch.mean)
+
  if grad is None: grad = prt * d
  else: grad += prt * d

+ params.set_(orig_params)
  assert grad is not None
  if n_samples > 1: grad.div_(n_samples)
+
+ # mean if got per-sample values
+ if loss is not None:
+ if loss.numel() > 1:
+ loss = loss.mean()
+
+ if loss_approx is not None:
+ if loss_approx.numel() > 1:
+ loss_approx = loss_approx.mean()
+
  return grad, loss, loss_approx

- SPSA = RandomizedFDM
+ class SPSA(RandomizedFDM):
+ """
+ Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.
+
+ Note:
+ This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+ and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+ Args:
+ h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+ n_samples (int, optional): number of random gradient samples. Defaults to 1.
+ formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+ distribution (Distributions, optional): distribution. Defaults to "rademacher".
+ beta (float, optional):
+ If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+ pre_generate (bool, optional):
+ whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+ seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+ target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+ References:
+ Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+ """

  class RDSA(RandomizedFDM):
+ """
+ Gradient approximation via Random-direction stochastic approximation (RDSA) method.
+
+ Note:
+ This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+ and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+ Args:
+ h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+ n_samples (int, optional): number of random gradient samples. Defaults to 1.
+ formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+ distribution (Distributions, optional): distribution. Defaults to "gaussian".
+ beta (float, optional):
+ If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+ pre_generate (bool, optional):
+ whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+ seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+ target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+ References:
+ Chen, Y. (2021). Theoretical study and comparison of SPSA and RDSA algorithms. arXiv preprint arXiv:2107.12771. https://arxiv.org/abs/2107.12771
+
+ """
  def __init__(
  self,
  h: float = 1e-3,
@@ -209,11 +433,34 @@ class RDSA(RandomizedFDM):
  super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)

  class GaussianSmoothing(RandomizedFDM):
+ """
+ Gradient approximation via Gaussian smoothing method.
+
+ Note:
+ This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+ and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+ Args:
+ h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-2.
+ n_samples (int, optional): number of random gradient samples. Defaults to 100.
+ formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
+ distribution (Distributions, optional): distribution. Defaults to "gaussian".
+ beta (float, optional):
+ If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+ pre_generate (bool, optional):
+ whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
+ seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
+ target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+
+ References:
+ Yurii Nesterov, Vladimir Spokoiny. (2015). Random Gradient-Free Minimization of Convex Functions. https://gwern.net/doc/math/2015-nesterov.pdf
+ """
  def __init__(
  self,
  h: float = 1e-2,
  n_samples: int = 100,
- formula: _FD_Formula = "central2",
+ formula: _FD_Formula = "forward2",
  distribution: Distributions = "gaussian",
  beta: float = 0,
  pre_generate = True,
@@ -223,21 +470,43 @@ class GaussianSmoothing(RandomizedFDM):
  super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)

  class MeZO(GradApproximator):
+ """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.
+
+ Note:
+ This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+ and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+
+ Args:
+ h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
+ n_samples (int, optional): number of random gradient samples. Defaults to 1.
+ formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
+ distribution (Distributions, optional): distribution. Defaults to "rademacher".
+ If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
+ target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+ References:
+ Malladi, S., Gao, T., Nichani, E., Damian, A., Lee, J. D., Chen, D., & Arora, S. (2023). Fine-tuning language models with just forward passes. Advances in Neural Information Processing Systems, 36, 53038-53075. https://arxiv.org/abs/2305.17333
+ """
+
  def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
  distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+
  defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
  super().__init__(defaults, target=target)

  def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
- return TensorList(params).sample_like(
- distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
- ).mul_(h)
+ prt = TensorList(params).sample_like(
+ distribution=distribution,
+ variance=h,
+ generator=torch.Generator(params[0].device).manual_seed(seed)
+ )
+ return prt

  def pre_step(self, var):
  h = NumberList(self.settings[p]['h'] for p in var.params)
- settings = self.settings[var.params[0]]
- n_samples = settings['n_samples']
- distribution = settings['distribution']
+
+ n_samples = self.defaults['n_samples']
+ distribution = self.defaults['distribution']

  step = var.current_step

@@ -251,7 +520,7 @@
  self.global_state['prt_fns'] = prt_fns

  @torch.no_grad
- def approximate(self, closure, params, loss, var):
+ def approximate(self, closure, params, loss):
  params = TensorList(params)
  loss_approx = None

@@ -263,7 +532,7 @@

  grad = None
  for i in range(n_samples):
- loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, v_0=loss)
+ loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, f_0=loss)
  if grad is None: grad = prt_fns[i]().mul_(d)
  else: grad += prt_fns[i]().mul_(d)

torchzero/modules/higher_order/__init__.py

@@ -1 +1 @@
- from .higher_order_newton import HigherOrderNewton
+ from .higher_order_newton import HigherOrderNewton
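
The `h = h**2 # because perturbation already multiplied by h` comment that recurs in the rfdm.py helpers above is easiest to see in a standalone calculation. Below is a minimal sketch of the central randomized finite-difference estimate that `_rcentral2` computes, assuming plain `torch`, a single flat parameter tensor, and a Rademacher direction; the function name and setup are illustrative and are not part of the torchzero API. Because the stored perturbation is already scaled by `h`, the difference quotient is divided by `h**2`: one factor of `h` forms the usual `(f(x+hu) - f(x-hu)) / (2h)` directional derivative, and the second cancels the scale of the perturbation that is then reused as the direction.

```python
import torch

def central_rfd_grad(f, x: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    # Rademacher direction u in {-1, +1}, pre-scaled by h as in the module
    u = torch.randint(0, 2, x.shape).to(x.dtype) * 2 - 1
    p = u * h
    f_plus = f(x + p)
    f_minus = f(x - p)
    d = (f_plus - f_minus) / (2 * h * h)  # directional derivative along u; note the h**2
    return p * d                          # estimate u * (grad . u); unbiased since E[u u^T] = I

# quick usage check on a quadratic
f = lambda x: (x ** 2).sum()
x = torch.randn(10)
print(central_rfd_grad(f, x))
```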
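Similarly, the `_seeded_perturbation` change in the MeZO hunk stores only an integer seed and regenerates identical noise on demand from a seeded `torch.Generator`, which is what keeps the method memory-efficient. A minimal sketch of that idea, again in plain `torch` on a single tensor with illustrative names rather than the package API:

```python
import torch

def perturbation(shape, seed: int, h: float, device: str = "cpu") -> torch.Tensor:
    # the same seed always reproduces the same tensor, so nothing has to be stored
    gen = torch.Generator(device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device) * h

x = torch.zeros(5)
seed = 1234
x += perturbation(x.shape, seed, h=1e-3)   # perturb parameters
# ... evaluate the loss at the perturbed point here ...
x -= perturbation(x.shape, seed, h=1e-3)   # regenerate the same noise and undo it
```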