torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -115,26 +115,26 @@ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[
  h = h**2 # because perturbation already multiplied by h
  return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)

- # another central4
- def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
- params += p_fn()
- f_1 = closure(False)
+ # # another central4
+ # def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+ # params += p_fn()
+ # f_1 = closure(False)

- params += p_fn() * 2
- f_3 = closure(False)
+ # params += p_fn() * 2
+ # f_3 = closure(False)

- params -= p_fn() * 4
- f_m1 = closure(False)
+ # params -= p_fn() * 4
+ # f_m1 = closure(False)

- params -= p_fn() * 2
- f_m3 = closure(False)
+ # params -= p_fn() * 2
+ # f_m3 = closure(False)

- params += p_fn() * 3
- h = h**2 # because perturbation already multiplied by h
- return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+ # params += p_fn() * 3
+ # h = h**2 # because perturbation already multiplied by h
+ # return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)


- _RFD_FUNCS = {
+ _RFD_FUNCS: dict[_FD_Formula, Callable] = {
  "forward": _rforward2,
  "forward2": _rforward2,
  "backward": _rbackward2,
@@ -147,14 +147,14 @@ _RFD_FUNCS = {
  "central4": _rcentral4,
  "forward4": _rforward4,
  "forward5": _rforward5,
- "bspsa4": _bgspsa4,
+ # "bspsa4": _bgspsa4,
  }


  class RandomizedFDM(GradApproximator):
  """Gradient approximation via a randomized finite-difference method.

- .. note::
+ Note:
  This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
  and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

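The `forward5` entry registered in `_RFD_FUNCS` above is the classical five-point one-sided difference rule. Writing $\phi(t) = f(x + t\,p)$ for the objective restricted to the (already $h$-scaled) perturbation $p$ and $\phi_k = \phi(k\,s)$, the standard rule is

```latex
\phi'(0) \;\approx\; \frac{-25\,\phi_0 + 48\,\phi_1 - 36\,\phi_2 + 16\,\phi_3 - 3\,\phi_4}{12\,s}
```

which matches the returned numerator with unit step $s = 1$ along $p$; the extra division by $h$ (the `h = h**2` line in the first hunk) compensates for the perturbation having been pre-multiplied by $h$, as the code comment notes.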
@@ -164,101 +164,57 @@ class RandomizedFDM(GradApproximator):
  formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
  distribution (Distributions, optional): distribution. Defaults to "rademacher".
  If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
- beta (float, optional): optinal momentum for generated perturbations. Defaults to 1e-3.
  pre_generate (bool, optional):
  whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
  seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
  target (GradTarget, optional): what to set on var. Defaults to "closure".

  Examples:
- #### Simultaneous perturbation stochastic approximation (SPSA) method
-
- SPSA is randomized finite differnce with rademacher distribution and central formula.
-
- .. code-block:: python
-
- spsa = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
- tz.m.LR(1e-2)
- )
-
- #### Random-direction stochastic approximation (RDSA) method
-
- RDSA is randomized finite differnce with usually gaussian distribution and central formula.
-
- .. code-block:: python
-
- rdsa = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
- tz.m.LR(1e-2)
- )
-
- #### RandomizedFDM with momentum
-
- Momentum might help by reducing the variance of the estimated gradients.
-
- .. code-block:: python
-
- momentum_spsa = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(),
- tz.m.HeavyBall(0.9),
- tz.m.LR(1e-3)
- )
-
- #### Gaussian smoothing method
-
- GS uses many gaussian samples with possibly a larger finite difference step size.
-
- .. code-block:: python
-
- gs = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
- tz.m.NewtonCG(hvp_method="forward"),
- tz.m.Backtracking()
- )
-
- #### SPSA-NewtonCG
-
- NewtonCG with hessian-vector product estimated via gradient difference
- calls closure multiple times per step. If each closure call estimates gradients
- with different perturbations, NewtonCG is unable to produce useful directions.
-
- By setting pre_generate to True, perturbations are generated once before each step,
- and each closure call estimates gradients using the same pre-generated perturbations.
- This way closure-based algorithms are able to use gradients estimated in a consistent way.
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(n_samples=10),
- tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
- tz.m.Backtracking()
- )
-
- #### SPSA-BFGS
-
- L-BFGS uses a memory of past parameter and gradient differences. If past gradients
- were estimated with different perturbations, L-BFGS directions will be useless.
-
- To alleviate this momentum can be added to random perturbations to make sure they only
- change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
- The disadvantage is that the subspace the algorithm is able to explore changes slowly.
-
- Additionally we will reset BFGS memory every 100 steps to remove influence from old gradient estimates.
-
- .. code-block:: python
-
- opt = tz.Modular(
- model.parameters(),
- tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99),
- tz.m.BFGS(reset_interval=100),
- tz.m.Backtracking()
- )
+ #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+ SPSA is randomized FDM with rademacher distribution and central formula.
+ ```py
+ spsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+ tz.m.LR(1e-2)
+ )
+ ```
+
+ #### Random-direction stochastic approximation (RDSA) method
+
+ RDSA is randomized FDM with usually gaussian distribution and central formula.
+ ```
+ rdsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+ tz.m.LR(1e-2)
+ )
+ ```
+
+ #### Gaussian smoothing method
+
+ GS uses many gaussian samples with possibly a larger finite difference step size.
+ ```
+ gs = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+ tz.m.NewtonCG(hvp_method="forward"),
+ tz.m.Backtracking()
+ )
+ ```
+
+ #### RandomizedFDM with momentum
+
+ Momentum might help by reducing the variance of the estimated gradients.
+ ```
+ momentum_spsa = tz.Modular(
+ model.parameters(),
+ tz.m.RandomizedFDM(),
+ tz.m.HeavyBall(0.9),
+ tz.m.LR(1e-3)
+ )
+ ```
  """
  PRE_MULTIPLY_BY_H = True
  def __init__(
@@ -267,106 +223,92 @@ class RandomizedFDM(GradApproximator):
  n_samples: int = 1,
  formula: _FD_Formula = "central",
  distribution: Distributions = "rademacher",
- beta: float = 0,
  pre_generate = True,
  seed: int | None | torch.Generator = None,
  target: GradTarget = "closure",
  ):
- defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
+ defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, pre_generate=pre_generate, seed=seed)
  super().__init__(defaults, target=target)

- def reset(self):
- self.state.clear()
- generator = self.global_state.get('generator', None) # avoid resetting generator
- self.global_state.clear()
- if generator is not None: self.global_state['generator'] = generator
-
- def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
- if 'generator' not in self.global_state:
- if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
- elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
- else: self.global_state['generator'] = None
- return self.global_state['generator']

  def pre_step(self, var):
- h, beta = self.get_settings(var.params, 'h', 'beta')
- settings = self.settings[var.params[0]]
- n_samples = settings['n_samples']
- distribution = settings['distribution']
- pre_generate = settings['pre_generate']
+ h = self.get_settings(var.params, 'h')
+ pre_generate = self.defaults['pre_generate']

  if pre_generate:
+ n_samples = self.defaults['n_samples']
+ distribution = self.defaults['distribution']
+
  params = TensorList(var.params)
- generator = self._get_generator(settings['seed'], var.params)
- perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+ generator = self.get_generator(params[0].device, self.defaults['seed'])
+ perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

+ # this is false for ForwardGradient where h isn't used and it subclasses this
  if self.PRE_MULTIPLY_BY_H:
  torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])

- if all(i==0 for i in beta):
- # just use pre-generated perturbations
- for param, prt in zip(params, zip(*perturbations)):
- self.state[param]['perturbations'] = prt
-
- else:
- # lerp old and new perturbations. This makes the subspace change gradually
- # which in theory might improve algorithms with history
- for i,p in enumerate(params):
- state = self.state[p]
- if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
-
- cur = [self.state[p]['perturbations'][:n_samples] for p in params]
- cur_flat = [p for l in cur for p in l]
- new_flat = [p for l in zip(*perturbations) for p in l]
- betas = [1-v for b in beta for v in [b]*n_samples]
- torch._foreach_lerp_(cur_flat, new_flat, betas)
+ for param, prt in zip(params, zip(*perturbations)):
+ self.state[param]['perturbations'] = prt

  @torch.no_grad
  def approximate(self, closure, params, loss):
  params = TensorList(params)
- orig_params = params.clone() # store to avoid small changes due to float imprecision
  loss_approx = None

  h = NumberList(self.settings[p]['h'] for p in params)
- settings = self.settings[params[0]]
- n_samples = settings['n_samples']
- fd_fn = _RFD_FUNCS[settings['formula']]
+ n_samples = self.defaults['n_samples']
+ distribution = self.defaults['distribution']
+ fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
  default = [None]*n_samples
  perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
- distribution = settings['distribution']
- generator = self._get_generator(settings['seed'], params)
+ generator = self.get_generator(params[0].device, self.defaults['seed'])

  grad = None
  for i in range(n_samples):
  prt = perturbations[i]
- if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
+
+ if prt[0] is None:
+ prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
  else: prt = TensorList(prt)

  loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+ # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+ # support for per-sample values which gives better estimate
+ if d[0].numel() > 1: d = d.map(torch.mean)
+
  if grad is None: grad = prt * d
  else: grad += prt * d

- params.set_(orig_params)
  assert grad is not None
  if n_samples > 1: grad.div_(n_samples)
+
+ # mean if got per-sample values
+ if loss is not None:
+ if loss.numel() > 1:
+ loss = loss.mean()
+
+ if loss_approx is not None:
+ if loss_approx.numel() > 1:
+ loss_approx = loss_approx.mean()
+
  return grad, loss, loss_approx

  class SPSA(RandomizedFDM):
  """
  Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

- .. note::
+ Note:
  This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
  and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

-
  Args:
  h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
  n_samples (int, optional): number of random gradient samples. Defaults to 1.
  formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
  distribution (Distributions, optional): distribution. Defaults to "rademacher".
- beta (float, optional):
- If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
  pre_generate (bool, optional):
  whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
  seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
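Tracing the `approximate` hunk above (`d` is the directional derivative returned by the finite-difference formula, `prt` is the $h$-scaled perturbation, and the accumulated `grad` is divided by `n_samples`), the quantity it builds for the default central formula is the familiar SPSA/RDSA estimate:

```latex
\hat{g}(x) \;=\; \frac{1}{N}\sum_{i=1}^{N} \Delta_i\,\frac{f(x + h\Delta_i) - f(x - h\Delta_i)}{2h}
```

For perturbation distributions with identity second moment (rademacher, unit-variance gaussian) each term equals $\nabla f(x)$ in expectation up to $O(h^2)$ terms, so increasing `n_samples` reduces variance rather than bias.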
@@ -380,7 +322,7 @@ class RDSA(RandomizedFDM):
  """
  Gradient approximation via Random-direction stochastic approximation (RDSA) method.

- .. note::
+ Note:
  This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
  and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -389,8 +331,6 @@ class RDSA(RandomizedFDM):
  n_samples (int, optional): number of random gradient samples. Defaults to 1.
  formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
  distribution (Distributions, optional): distribution. Defaults to "gaussian".
- beta (float, optional):
- If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
  pre_generate (bool, optional):
  whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
  seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -406,18 +346,17 @@ class RDSA(RandomizedFDM):
  n_samples: int = 1,
  formula: _FD_Formula = "central2",
  distribution: Distributions = "gaussian",
- beta: float = 0,
  pre_generate = True,
  target: GradTarget = "closure",
  seed: int | None | torch.Generator = None,
  ):
- super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+ super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

  class GaussianSmoothing(RandomizedFDM):
  """
  Gradient approximation via Gaussian smoothing method.

- .. note::
+ Note:
  This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
  and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -426,8 +365,6 @@ class GaussianSmoothing(RandomizedFDM):
  n_samples (int, optional): number of random gradient samples. Defaults to 100.
  formula (_FD_Formula, optional): finite difference formula. Defaults to 'forward2'.
  distribution (Distributions, optional): distribution. Defaults to "gaussian".
- beta (float, optional):
- If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
  pre_generate (bool, optional):
  whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
  seed (int | None | torch.Generator, optional): Seed for random generator. Defaults to None.
@@ -443,17 +380,16 @@ class GaussianSmoothing(RandomizedFDM):
  n_samples: int = 100,
  formula: _FD_Formula = "forward2",
  distribution: Distributions = "gaussian",
- beta: float = 0,
  pre_generate = True,
  target: GradTarget = "closure",
  seed: int | None | torch.Generator = None,
  ):
- super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+ super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,pre_generate=pre_generate,target=target,seed=seed)

  class MeZO(GradApproximator):
  """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

- .. note::
+ Note:
  This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
  and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -476,15 +412,18 @@ class MeZO(GradApproximator):
  super().__init__(defaults, target=target)

  def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
- return TensorList(params).sample_like(
- distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
- ).mul_(h)
+ prt = TensorList(params).sample_like(
+ distribution=distribution,
+ variance=h,
+ generator=torch.Generator(params[0].device).manual_seed(seed)
+ )
+ return prt

  def pre_step(self, var):
  h = NumberList(self.settings[p]['h'] for p in var.params)
- settings = self.settings[var.params[0]]
- n_samples = settings['n_samples']
- distribution = settings['distribution']
+
+ n_samples = self.defaults['n_samples']
+ distribution = self.defaults['distribution']

  step = var.current_step

@@ -503,9 +442,9 @@
  loss_approx = None

  h = NumberList(self.settings[p]['h'] for p in params)
- settings = self.settings[params[0]]
- n_samples = settings['n_samples']
- fd_fn = _RFD_FUNCS[settings['formula']]
+ n_samples = self.defaults['n_samples']
+ fd_fn = _RFD_FUNCS[self.defaults['formula']]
+
  prt_fns = self.global_state['prt_fns']

  grad = None
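The `_seeded_perturbation` change above highlights the trick MeZO is built on: a perturbation that is a pure function of its seed never has to be stored, it can simply be regenerated for each closure evaluation. A minimal standalone sketch of that idea (a hypothetical helper, gaussian case only, with the perturbation scaled by `h` purely for illustration):

```py
import torch

def seeded_gaussian_perturbation(params: list[torch.Tensor], seed: int, h: float) -> list[torch.Tensor]:
    # Regenerate the same perturbation on demand from its seed instead of keeping
    # it in memory between closure evaluations; this is what makes MeZO memory-efficient.
    gen = torch.Generator(device=params[0].device).manual_seed(seed)
    return [torch.randn(p.shape, generator=gen, device=p.device, dtype=p.dtype) * h
            for p in params]

# Two calls with the same seed produce identical tensors, so f(x + prt) and
# f(x - prt) can be evaluated without ever holding prt for the whole step:
# prt1 = seeded_gaussian_perturbation(params, seed=123, h=1e-3)
# prt2 = seeded_gaussian_perturbation(params, seed=123, h=1e-3)
# assert all(torch.equal(a, b) for a, b in zip(prt1, prt2))
```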
@@ -1 +1 @@
- from .higher_order_newton import HigherOrderNewton
+ from .higher_order_newton import HigherOrderNewton
@@ -13,7 +13,7 @@ import torch
  from ...core import Chainable, Module, apply_transform
  from ...utils import TensorList, vec_to_tensors, vec_to_tensors_
  from ...utils.derivatives import (
- hessian_list_to_mat,
+ flatten_jacobian,
  jacobian_wrt,
  )

@@ -148,21 +148,16 @@ class HigherOrderNewton(Module):
  """A basic arbitrary order newton's method with optional trust region and proximal penalty.

  This constructs an nth order taylor approximation via autograd and minimizes it with
- scipy.optimize.minimize trust region newton solvers with optional proximal penalty.
+ ``scipy.optimize.minimize`` trust region newton solvers with optional proximal penalty.

- .. note::
- In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the :code:`inner` argument if you wish to apply Newton preconditioning to another module's output.
+ The hessian of taylor approximation is easier to evaluate, plus it can be evaluated in a batched mode,
+ so it can be more efficient in very specific instances.

- .. note::
- This module requires the a closure passed to the optimizer step,
- as it needs to re-evaluate the loss and gradients for calculating higher order derivatives.
- The closure must accept a ``backward`` argument (refer to documentation).
-
- .. warning::
- this uses roughly O(N^order) memory and solving the subproblem can be very expensive.
-
- .. warning::
- "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.
+ Notes:
+ - In most cases HigherOrderNewton should be the first module in the chain because it relies on extra autograd. Use the ``inner`` argument if you wish to apply Newton preconditioning to another module's output.
+ - This module requires the a closure passed to the optimizer step, as it needs to re-evaluate the loss and gradients for calculating higher order derivatives. The closure must accept a ``backward`` argument (refer to documentation).
+ - this uses roughly O(N^order) memory and solving the subproblem is very expensive.
+ - "none" and "proximal" trust methods may generate subproblems that have no minima, causing divergence.

  Args:

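For readers skimming the docstring above, the "nth order taylor approximation" being minimized is the standard truncated expansion around the current point $x$, with the higher derivative tensors obtained by the repeated `jacobian_wrt` calls shown in a later hunk:

```latex
m(d) \;=\; f(x) + \sum_{k=1}^{n} \frac{1}{k!}\, D^k f(x)[\underbrace{d,\dots,d}_{k}]
```

subject to the trust-region bounds or the proximal penalty described above.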
@@ -178,7 +173,7 @@
  increase (float, optional): trust region multiplier on good steps. Defaults to 1.5.
  decrease (float, optional): trust region multiplier on bad steps. Defaults to 0.75.
  trust_init (float | None, optional):
- initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on :code:`"proximal"`. Defaults to None.
+ initial trust region size. If none, defaults to 1 on :code:`trust_method="bounds"` and 0.1 on ``"proximal"``. Defaults to None.
  trust_tol (float, optional):
  Maximum ratio of expected loss reduction to actual reduction for trust region increase.
  Should 1 or higer. Defaults to 2.
@@ -191,11 +186,14 @@
  self,
  order: int = 4,
  trust_method: Literal['bounds', 'proximal', 'none'] | None = 'bounds',
- nplus: float = 2,
+ nplus: float = 3.5,
  nminus: float = 0.25,
+ rho_good: float = 0.99,
+ rho_bad: float = 1e-4,
  init: float | None = None,
  eta: float = 1e-6,
  max_attempts = 10,
+ boundary_tol: float = 1e-2,
  de_iters: int | None = None,
  vectorize: bool = True,
  ):
@@ -203,7 +201,7 @@
  if trust_method == 'bounds': init = 1
  else: init = 0.1

- defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts)
+ defaults = dict(order=order, trust_method=trust_method, nplus=nplus, nminus=nminus, eta=eta, init=init, vectorize=vectorize, de_iters=de_iters, max_attempts=max_attempts, boundary_tol=boundary_tol, rho_good=rho_good, rho_bad=rho_bad)
  super().__init__(defaults)

  @torch.no_grad
@@ -222,6 +220,9 @@
  de_iters = settings['de_iters']
  max_attempts = settings['max_attempts']
  vectorize = settings['vectorize']
+ boundary_tol = settings['boundary_tol']
+ rho_good = settings['rho_good']
+ rho_bad = settings['rho_bad']

  # ------------------------ calculate grad and hessian ------------------------ #
  with torch.enable_grad():
@@ -241,7 +242,7 @@
  T_list = jacobian_wrt([T], params, create_graph=not is_last, batched=vectorize)
  with torch.no_grad() if is_last else nullcontext():
  # the shape is (ndim, ) * order
- T = hessian_list_to_mat(T_list).view(n, n, *T.shape[1:])
+ T = flatten_jacobian(T_list).view(n, n, *T.shape[1:])
  derivatives.append(T)

  x0 = torch.cat([p.ravel() for p in params])
@@ -254,8 +255,13 @@

  # load trust region value
  trust_value = self.global_state.get('trust_region', init)
- if trust_value < 1e-8 or trust_value > 1e16: trust_value = self.global_state['trust_region'] = settings['init']

+ # make sure its not too small or too large
+ finfo = torch.finfo(x0.dtype)
+ if trust_value < finfo.tiny*2 or trust_value > finfo.max / (2*nplus):
+ trust_value = self.global_state['trust_region'] = settings['init']
+
+ # determine tr and prox values
  if trust_method is None: trust_method = 'none'
  else: trust_method = trust_method.lower()

@@ -297,13 +303,15 @@

  rho = reduction / (max(pred_reduction, 1e-8))
  # failed step
- if rho < 0.25:
+ if rho < rho_bad:
  self.global_state['trust_region'] = trust_value * nminus

  # very good step
- elif rho > 0.75:
- diff = trust_value - (x0 - x_star).abs_()
- if (diff.amin() / trust_value) > 1e-4: # hits boundary
+ elif rho > rho_good:
+ step = (x_star - x0)
+ magn = torch.linalg.vector_norm(step) # pylint:disable=not-callable
+ if trust_method == 'proximal' or (trust_value - magn) / trust_value <= boundary_tol:
+ # close to boundary
  self.global_state['trust_region'] = trust_value * nplus

  # if the ratio is high enough then accept the proposed step
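The branch above is the usual trust-region ratio test, now driven by the configurable `rho_bad`, `rho_good`, `nminus`, `nplus` and `boundary_tol` settings. A self-contained sketch of the same update, written as a hypothetical standalone helper using the constructor defaults from this diff:

```py
def update_trust_radius(radius: float, rho: float, step_norm: float,
                        trust_method: str = "bounds",
                        nplus: float = 3.5, nminus: float = 0.25,
                        rho_good: float = 0.99, rho_bad: float = 1e-4,
                        boundary_tol: float = 1e-2) -> float:
    # rho = actual loss reduction / reduction predicted by the polynomial model
    if rho < rho_bad:
        # the model predicted this step poorly: shrink the region
        return radius * nminus
    if rho > rho_good and (trust_method == "proximal"
                           or (radius - step_norm) / radius <= boundary_tol):
        # accurate model and the step (nearly) reached the boundary: expand
        return radius * nplus
    # middling ratio, or a good step well inside the region: keep the radius
    return radius
```

This mirrors the two branches in the hunk: shrink on a bad ratio, and expand only when the ratio is high and the proposed step actually pressed against the current region (or when running in proximal mode, where there is no hard boundary).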
@@ -0,0 +1 @@
+ from .gn import SumOfSquares, GaussNewton