torchzero-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. torchzero/__init__.py +4 -0
  2. torchzero/core/__init__.py +13 -0
  3. torchzero/core/module.py +471 -0
  4. torchzero/core/tensorlist_optimizer.py +219 -0
  5. torchzero/modules/__init__.py +21 -0
  6. torchzero/modules/adaptive/__init__.py +4 -0
  7. torchzero/modules/adaptive/adaptive.py +192 -0
  8. torchzero/modules/experimental/__init__.py +19 -0
  9. torchzero/modules/experimental/experimental.py +294 -0
  10. torchzero/modules/experimental/quad_interp.py +104 -0
  11. torchzero/modules/experimental/subspace.py +259 -0
  12. torchzero/modules/gradient_approximation/__init__.py +7 -0
  13. torchzero/modules/gradient_approximation/_fd_formulas.py +3 -0
  14. torchzero/modules/gradient_approximation/base_approximator.py +110 -0
  15. torchzero/modules/gradient_approximation/fdm.py +125 -0
  16. torchzero/modules/gradient_approximation/forward_gradient.py +163 -0
  17. torchzero/modules/gradient_approximation/newton_fdm.py +198 -0
  18. torchzero/modules/gradient_approximation/rfdm.py +125 -0
  19. torchzero/modules/line_search/__init__.py +30 -0
  20. torchzero/modules/line_search/armijo.py +56 -0
  21. torchzero/modules/line_search/base_ls.py +139 -0
  22. torchzero/modules/line_search/directional_newton.py +217 -0
  23. torchzero/modules/line_search/grid_ls.py +158 -0
  24. torchzero/modules/line_search/scipy_minimize_scalar.py +62 -0
  25. torchzero/modules/meta/__init__.py +12 -0
  26. torchzero/modules/meta/alternate.py +65 -0
  27. torchzero/modules/meta/grafting.py +195 -0
  28. torchzero/modules/meta/optimizer_wrapper.py +173 -0
  29. torchzero/modules/meta/return_overrides.py +46 -0
  30. torchzero/modules/misc/__init__.py +10 -0
  31. torchzero/modules/misc/accumulate.py +43 -0
  32. torchzero/modules/misc/basic.py +115 -0
  33. torchzero/modules/misc/lr.py +96 -0
  34. torchzero/modules/misc/multistep.py +51 -0
  35. torchzero/modules/misc/on_increase.py +53 -0
  36. torchzero/modules/momentum/__init__.py +4 -0
  37. torchzero/modules/momentum/momentum.py +106 -0
  38. torchzero/modules/operations/__init__.py +29 -0
  39. torchzero/modules/operations/multi.py +298 -0
  40. torchzero/modules/operations/reduction.py +134 -0
  41. torchzero/modules/operations/singular.py +113 -0
  42. torchzero/modules/optimizers/__init__.py +10 -0
  43. torchzero/modules/optimizers/adagrad.py +49 -0
  44. torchzero/modules/optimizers/adam.py +118 -0
  45. torchzero/modules/optimizers/lion.py +28 -0
  46. torchzero/modules/optimizers/rmsprop.py +51 -0
  47. torchzero/modules/optimizers/rprop.py +99 -0
  48. torchzero/modules/optimizers/sgd.py +54 -0
  49. torchzero/modules/orthogonalization/__init__.py +2 -0
  50. torchzero/modules/orthogonalization/newtonschulz.py +159 -0
  51. torchzero/modules/orthogonalization/svd.py +86 -0
  52. torchzero/modules/quasi_newton/__init__.py +4 -0
  53. torchzero/modules/regularization/__init__.py +22 -0
  54. torchzero/modules/regularization/dropout.py +34 -0
  55. torchzero/modules/regularization/noise.py +77 -0
  56. torchzero/modules/regularization/normalization.py +328 -0
  57. torchzero/modules/regularization/ortho_grad.py +78 -0
  58. torchzero/modules/regularization/weight_decay.py +92 -0
  59. torchzero/modules/scheduling/__init__.py +2 -0
  60. torchzero/modules/scheduling/lr_schedulers.py +131 -0
  61. torchzero/modules/scheduling/step_size.py +80 -0
  62. torchzero/modules/second_order/__init__.py +4 -0
  63. torchzero/modules/second_order/newton.py +165 -0
  64. torchzero/modules/smoothing/__init__.py +5 -0
  65. torchzero/modules/smoothing/gaussian_smoothing.py +90 -0
  66. torchzero/modules/smoothing/laplacian_smoothing.py +128 -0
  67. torchzero/modules/weight_averaging/__init__.py +2 -0
  68. torchzero/modules/weight_averaging/ema.py +72 -0
  69. torchzero/modules/weight_averaging/swa.py +171 -0
  70. torchzero/optim/__init__.py +10 -0
  71. torchzero/optim/experimental/__init__.py +20 -0
  72. torchzero/optim/experimental/experimental.py +343 -0
  73. torchzero/optim/experimental/ray_search.py +83 -0
  74. torchzero/optim/first_order/__init__.py +18 -0
  75. torchzero/optim/first_order/cautious.py +158 -0
  76. torchzero/optim/first_order/forward_gradient.py +70 -0
  77. torchzero/optim/first_order/optimizers.py +570 -0
  78. torchzero/optim/modular.py +132 -0
  79. torchzero/optim/quasi_newton/__init__.py +1 -0
  80. torchzero/optim/quasi_newton/directional_newton.py +58 -0
  81. torchzero/optim/second_order/__init__.py +1 -0
  82. torchzero/optim/second_order/newton.py +94 -0
  83. torchzero/optim/wrappers/__init__.py +0 -0
  84. torchzero/optim/wrappers/nevergrad.py +113 -0
  85. torchzero/optim/wrappers/nlopt.py +165 -0
  86. torchzero/optim/wrappers/scipy.py +439 -0
  87. torchzero/optim/zeroth_order/__init__.py +4 -0
  88. torchzero/optim/zeroth_order/fdm.py +87 -0
  89. torchzero/optim/zeroth_order/newton_fdm.py +146 -0
  90. torchzero/optim/zeroth_order/rfdm.py +217 -0
  91. torchzero/optim/zeroth_order/rs.py +85 -0
  92. torchzero/random/__init__.py +1 -0
  93. torchzero/random/random.py +46 -0
  94. torchzero/tensorlist.py +819 -0
  95. torchzero/utils/__init__.py +0 -0
  96. torchzero/utils/compile.py +39 -0
  97. torchzero/utils/derivatives.py +99 -0
  98. torchzero/utils/python_tools.py +25 -0
  99. torchzero/utils/torch_tools.py +92 -0
  100. torchzero-0.0.1.dist-info/LICENSE +21 -0
  101. torchzero-0.0.1.dist-info/METADATA +118 -0
  102. torchzero-0.0.1.dist-info/RECORD +104 -0
  103. torchzero-0.0.1.dist-info/WHEEL +5 -0
  104. torchzero-0.0.1.dist-info/top_level.txt +1 -0
torchzero/optim/wrappers/scipy.py (new file)
@@ -0,0 +1,439 @@
+ from typing import Literal, Any
+ from collections import abc
+ from functools import partial
+ 
+ import numpy as np
+ import torch
+ 
+ import scipy.optimize
+ 
+ from ...core import _ClosureType, TensorListOptimizer
+ from ...utils.derivatives import jacobian, jacobian_list_to_vec, hessian, hessian_list_to_mat, jacobian_and_hessian
+ from ...modules import WrapClosure
+ from ...modules.experimental.subspace import Projection, Proj2Masks, ProjGrad, ProjNormalize, Subspace
+ from ...modules.second_order.newton import regularize_hessian_
+ from ...tensorlist import TensorList
+ from ..modular import Modular
+ 
+ 
+ def _ensure_float(x):
+     if isinstance(x, torch.Tensor): return x.detach().cpu().item()
+     if isinstance(x, np.ndarray): return x.item()
+     return float(x)
+ 
+ def _ensure_numpy(x):
+     if isinstance(x, torch.Tensor): return x.detach().cpu()
+     if isinstance(x, np.ndarray): return x
+     return np.array(x)
+ 
+ class ScipyMinimize(TensorListOptimizer):
+     """Use scipy.optimize.minimize as a pytorch optimizer. Note that this performs full minimization on each step,
+     so usually you would want to perform a single step, although performing multiple steps will refine the
+     solution.
+ 
+     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
+     for a detailed description of args.
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str | None, optional): type of solver.
+             If None, scipy will select one of BFGS, L-BFGS-B, SLSQP,
+             depending on whether or not the problem has constraints or bounds.
+             Defaults to None.
+         lb, ub (optional): lower and upper bounds on variables. Defaults to None.
+         constraints (tuple, optional): constraints definition. Defaults to ().
+         tol (float | None, optional): Tolerance for termination. Defaults to None.
+         callback (Callable | None, optional): A callable called after each iteration. Defaults to None.
+         options (dict | None, optional): A dictionary of solver options. Defaults to None.
+         jac (str, optional): Method for computing the gradient vector.
+             Only for CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+             In addition to scipy options, this supports 'autograd', which uses pytorch autograd.
+             This setting is ignored for methods that don't require gradient. Defaults to 'autograd'.
+         hess (str, optional):
+             Method for computing the Hessian matrix.
+             Only for Newton-CG, dogleg, trust-ncg, trust-krylov, trust-exact and trust-constr.
+             This setting is ignored for methods that don't require hessian. Defaults to 'autograd'.
+         tikhonov (float, optional):
+             optional hessian regularizer value. Only has effect for methods that require hessian.
+     """
+     def __init__(
+         self,
+         params,
+         method: Literal['nelder-mead', 'powell', 'cg', 'bfgs', 'newton-cg',
+                         'l-bfgs-b', 'tnc', 'cobyla', 'cobyqa', 'slsqp',
+                         'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact',
+                         'trust-krylov'] | str | None = None,
+         lb = None,
+         ub = None,
+         constraints = (),
+         tol: float | None = None,
+         callback = None,
+         options = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = 'autograd',
+         tikhonov: float | Literal['eig'] = 0,
+     ):
+         defaults = dict(lb=lb, ub=ub)
+         super().__init__(params, defaults)
+         self.method = method
+         self.constraints = constraints
+         self.tol = tol
+         self.callback = callback
+         self.options = options
+ 
+         self.jac = jac
+         self.hess = hess
+         self.tikhonov: float | Literal['eig'] = tikhonov
+ 
+         self.use_jac_autograd = jac.lower() == 'autograd' and (method is None or method.lower() in [
+             'cg', 'bfgs', 'newton-cg', 'l-bfgs-b', 'tnc', 'slsqp', 'dogleg',
+             'trust-ncg', 'trust-krylov', 'trust-exact', 'trust-constr',
+         ])
+         self.use_hess_autograd = isinstance(hess, str) and hess.lower() == 'autograd' and method is not None and method.lower() in [
+             'newton-cg', 'dogleg', 'trust-ncg', 'trust-krylov', 'trust-exact'
+         ]
+ 
+         if self.jac == 'autograd':
+             if self.use_jac_autograd: self.jac = True
+             else: self.jac = None
+ 
+ 
+     def _hess(self, x: np.ndarray, params: TensorList, closure: _ClosureType): # type:ignore
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         with torch.enable_grad():
+             value = closure(False)
+             H = hessian([value], wrt = params) # type:ignore
+         Hmat = hessian_list_to_mat(H)
+         regularize_hessian_(Hmat, self.tikhonov)
+         return Hmat.detach().cpu().numpy()
+ 
+     def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+         # set params to x
+         params.from_vec_(torch.from_numpy(x).to(params[0], copy=False))
+ 
+         # return value and maybe gradients
+         if self.use_jac_autograd:
+             with torch.enable_grad(): value = _ensure_float(closure())
+             return value, params.ensure_grad_().grad.to_vec().detach().cpu().numpy()
+         return _ensure_float(closure(False))
+ 
+     @torch.no_grad
+     def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+         params = self.get_params()
+ 
+         # determine hess argument
+         if self.hess == 'autograd':
+             if self.use_hess_autograd: hess = partial(self._hess, params = params, closure = closure)
+             else: hess = None
+         else: hess = self.hess
+ 
+         x0 = params.to_vec().detach().cpu().numpy()
+ 
+         # make bounds
+         lb, ub = self.get_group_keys('lb', 'ub', cls=list)
+         bounds = []
+         for p, l, u in zip(params, lb, ub):
+             bounds.extend([(l, u)] * p.numel())
+ 
+         if self.method is not None and (self.method.lower() == 'tnc' or self.method.lower() == 'slsqp'):
+             x0 = x0.astype(np.float64) # those methods error without this
+ 
+         res = scipy.optimize.minimize(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             method=self.method,
+             bounds=bounds,
+             constraints=self.constraints,
+             tol=self.tol,
+             callback=self.callback,
+             options=self.options,
+             jac = self.jac,
+             hess = hess,
+         )
+ 
+         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return res.fun
+ 
+ 
+ 
+ class ScipyRoot(TensorListOptimizer):
+     """Find a root of a vector function (UNTESTED!).
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str, optional): type of solver, one of the methods supported by scipy.optimize.root. Defaults to 'hybr'.
+         tol (float | None, optional): tolerance for termination. Defaults to None.
+         callback (Callable | None, optional): a callable called after each iteration. Defaults to None.
+         options (dict | None, optional): a dictionary of solver options. Defaults to None.
+         jac (str, optional): method for computing the Jacobian. Supports 'autograd', which uses pytorch autograd. Defaults to 'autograd'.
+     """
+     def __init__(
+         self,
+         params,
+         method: Literal[
+             "hybr",
+             "lm",
+             "broyden1",
+             "broyden2",
+             "anderson",
+             "linearmixing",
+             "diagbroyden",
+             "excitingmixing",
+             "krylov",
+             "df-sane",
+         ] = 'hybr',
+         tol: float | None = None,
+         callback = None,
+         options = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+     ):
+         super().__init__(params, {})
+         self.method = method
+         self.tol = tol
+         self.callback = callback
+         self.options = options
+ 
+         self.jac = jac
+         if self.jac == 'autograd': self.jac = True
+ 
+     def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+         # set params to x
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ 
+         # return value and maybe gradients
+         if self.jac:
+             with torch.enable_grad():
+                 value = closure(False)
+                 if not isinstance(value, torch.Tensor):
+                     raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(value)}")
+                 jac = jacobian_list_to_vec(jacobian([value], wrt=params))
+             return _ensure_numpy(value), jac.detach().cpu().numpy()
+         return _ensure_numpy(closure(False))
+ 
+     @torch.no_grad
+     def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+         params = self.get_params()
+ 
+         x0 = params.to_vec().detach().cpu().numpy()
+ 
+         res = scipy.optimize.root(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             method=self.method,
+             tol=self.tol,
+             callback=self.callback,
+             options=self.options,
+             jac = self.jac,
+         )
+ 
+         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return res.fun
+ 
+ 
+ class ScipyRootOptimization(TensorListOptimizer):
+     """Optimization via finding roots of the gradient with `scipy.optimize.root` (for experiments, won't work well on most problems).
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         method (str, optional): one of the methods from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html#scipy.optimize.root. Defaults to 'hybr'.
+         tol (float | None, optional): tolerance. Defaults to None.
+         callback (Callable | None, optional): callback. Defaults to None.
+         options (dict | None, optional): solver options. Defaults to None.
+         jac (str, optional): jacobian calculation method. Defaults to 'autograd'.
+         tikhonov (float | Literal['eig'], optional): tikhonov regularization (only for 'hybr' and 'lm'). Defaults to 0.
+         add_loss (float, optional): adds loss value to jacobian multiplied by this to try to avoid finding maxima. Defaults to 0.
+         mul_loss (float, optional): multiplies jacobian by loss value multiplied by this to try to avoid finding maxima. Defaults to 0.
+     """
+     def __init__(
+         self,
+         params,
+         method: Literal[
+             "hybr",
+             "lm",
+             "broyden1",
+             "broyden2",
+             "anderson",
+             "linearmixing",
+             "diagbroyden",
+             "excitingmixing",
+             "krylov",
+             "df-sane",
+         ] = 'hybr',
+         tol: float | None = None,
+         callback = None,
+         options = None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         tikhonov: float | Literal['eig'] = 0,
+         add_loss: float = 0,
+         mul_loss: float = 0,
+     ):
+         super().__init__(params, {})
+         self.method = method
+         self.tol = tol
+         self.callback = callback
+         self.options = options
+         self.value = None
+         self.tikhonov: float | Literal['eig'] = tikhonov
+         self.add_loss = add_loss
+         self.mul_loss = mul_loss
+ 
+         self.jac = jac == 'autograd'
+ 
+         # those don't require jacobian
+         if self.method.lower() in ('broyden1', 'broyden2', 'anderson', 'linearmixing', 'diagbroyden', 'excitingmixing', 'krylov', 'df-sane'):
+             self.jac = None
+ 
+     def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+         # set params to x
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+ 
+         # return gradients and maybe hessian
+         if self.jac:
+             with torch.enable_grad():
+                 self.value = closure(False)
+                 if not isinstance(self.value, torch.Tensor):
+                     raise TypeError(f"Autograd jacobian requires closure to return torch.Tensor, got {type(self.value)}")
+                 jac_list, hess_list = jacobian_and_hessian([self.value], wrt=params)
+             jac = jacobian_list_to_vec(jac_list)
+             hess = hessian_list_to_mat(hess_list)
+             regularize_hessian_(hess, self.tikhonov)
+             if self.mul_loss != 0: jac *= self.value * self.mul_loss
+             if self.add_loss != 0: jac += self.value * self.add_loss
+             return jac.detach().cpu().numpy(), hess.detach().cpu().numpy()
+ 
+         # return the gradients
+         with torch.enable_grad(): self.value = closure()
+         jac = params.ensure_grad_().grad.to_vec()
+         if self.mul_loss != 0: jac *= self.value * self.mul_loss
+         if self.add_loss != 0: jac += self.value * self.add_loss
+         return jac.detach().cpu().numpy()
+ 
+     @torch.no_grad
+     def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+         params = self.get_params()
+ 
+         x0 = params.to_vec().detach().cpu().numpy()
+ 
+         res = scipy.optimize.root(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             method=self.method,
+             tol=self.tol,
+             callback=self.callback,
+             options=self.options,
+             jac = self.jac,
+         )
+ 
+         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return self.value
+ 
+ class ScipyDE(TensorListOptimizer):
+     """Use scipy.optimize.differential_evolution as a pytorch optimizer. Note that this performs full minimization on each step,
+     so usually you would want to perform a single step. This also requires bounds to be specified.
+ 
+     Please refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
+     for all other args.
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         bounds (tuple[float,float], optional): tuple with lower and upper bounds.
+             DE requires bounds to be specified.
+ 
+     other args:
+         refer to https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.differential_evolution.html
+     """
+     def __init__(
+         self,
+         params,
+         bounds: tuple[float,float],
+         strategy: Literal['best1bin', 'best1exp', 'rand1bin', 'rand1exp', 'rand2bin', 'rand2exp',
+                           'randtobest1bin', 'randtobest1exp', 'currenttobest1bin', 'currenttobest1exp',
+                           'best2exp', 'best2bin'] = 'best1bin',
+         maxiter: int = 1000,
+         popsize: int = 15,
+         tol: float = 0.01,
+         mutation = (0.5, 1),
+         recombination: float = 0.7,
+         seed = None,
+         callback = None,
+         disp: bool = False,
+         polish: bool = False,
+         init: str = 'latinhypercube',
+         atol: int = 0,
+         updating: str = 'immediate',
+         workers: int = 1,
+         constraints = (),
+         *,
+         integrality = None,
+ 
+     ):
+         super().__init__(params, {})
+ 
+         kwargs = locals().copy()
+         del kwargs['self'], kwargs['params'], kwargs['bounds'], kwargs['__class__']
+         self._kwargs = kwargs
+         self._lb, self._ub = bounds
+ 
+     def _objective(self, x: np.ndarray, params: TensorList, closure: _ClosureType):
+         params.from_vec_(torch.from_numpy(x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return _ensure_float(closure(False))
+ 
+     @torch.no_grad
+     def step(self, closure: _ClosureType): # type:ignore # pylint:disable = signature-differs
+         params = self.get_params()
+ 
+         x0 = params.to_vec().detach().cpu().numpy()
+         bounds = [(self._lb, self._ub)] * len(x0)
+ 
+         res = scipy.optimize.differential_evolution(
+             partial(self._objective, params = params, closure = closure),
+             x0 = x0,
+             bounds=bounds,
+             **self._kwargs
+         )
+ 
+         params.from_vec_(torch.from_numpy(res.x).to(device = params[0].device, dtype=params[0].dtype, copy=False))
+         return res.fun
+ 
+ 
+ class ScipyMinimizeSubspace(Modular):
+     """For experiments; won't work well on most problems.
+ 
+     Optimizes in a small subspace using scipy.optimize.minimize, but doesn't seem to work well."""
+     def __init__(
+         self,
+         params,
+         projections: Projection | abc.Iterable[Projection] = (
+             Proj2Masks(5),
+             ProjNormalize(
+                 ProjGrad(),
+             )
+         ),
+         method=None,
+         lb = None,
+         ub = None,
+         constraints=(),
+         tol=None,
+         callback=None,
+         options=None,
+         jac: Literal['2-point', '3-point', 'cs', 'autograd'] = 'autograd',
+         hess: Literal['2-point', '3-point', 'cs', 'autograd'] | scipy.optimize.HessianUpdateStrategy = '2-point',
+     ):
+ 
+         scopt = WrapClosure(
+             ScipyMinimize,
+             method = method,
+             lb = lb,
+             ub = ub,
+             constraints = constraints,
+             tol = tol,
+             callback = callback,
+             options = options,
+             jac = jac,
+             hess = hess
+         )
+         modules = [
+             Subspace(scopt, projections),
+         ]
+ 
+         super().__init__(params, modules)
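Editor's note: the sketch below is not part of the package diff. It shows how `ScipyMinimize` above might be driven, assuming the closure convention implied by `_objective` and `step`: `closure()` computes the loss and populates gradients, while `closure(False)` returns the loss only. The import path follows the file listing (torchzero/optim/wrappers/scipy.py); the model and data are placeholders.

```
import torch
from torchzero.optim.wrappers.scipy import ScipyMinimize

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# one step runs a full scipy.optimize.minimize solve over the flattened parameters
opt = ScipyMinimize(model.parameters(), method='l-bfgs-b')

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        # with jac='autograd', ScipyMinimize reads gradients off the parameters
        opt.zero_grad()
        loss.backward()
    return loss

final_loss = opt.step(closure)
```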
torchzero/optim/zeroth_order/__init__.py (new file)
@@ -0,0 +1,4 @@
+ from .fdm import FDM, FDMWrapper
+ from .newton_fdm import NewtonFDM, RandomSubspaceNewtonFDM
+ from .rfdm import RandomGaussianSmoothing, RandomizedFDM, RandomizedFDMWrapper, SPSA
+ from .rs import RandomSearch, CyclicRS
torchzero/optim/zeroth_order/fdm.py (new file)
@@ -0,0 +1,87 @@
+ from typing import Literal
+ 
+ import torch
+ 
+ from ...modules import FDM as _FDM, WrapClosure, SGD, WeightDecay, LR
+ from ...modules.gradient_approximation._fd_formulas import _FD_Formulas
+ from ..modular import Modular
+ 
+ 
+ class FDM(Modular):
+     """Gradient approximation via finite difference.
+ 
+     This performs `n + 1` evaluations per step with `forward` and `backward` formulas,
+     and `2 * n` with `central` formula, where n is the number of parameters.
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float, optional): learning rate. Defaults to 1e-3.
+         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
+         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
+         n_points (Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
+         momentum (float, optional): momentum. Defaults to 0.
+         dampening (float, optional): momentum dampening. Defaults to 0.
+         nesterov (bool, optional):
+             enables nesterov momentum, otherwise uses heavyball momentum. Defaults to False.
+         weight_decay (float, optional): weight decay (L2 regularization). Defaults to 0.
+         decoupled (bool, optional):
+             decouples weight decay from gradient. If True, weight decay doesn't depend on learning rate. Defaults to False.
+     """
+     def __init__(
+         self,
+         params,
+         lr: float = 1e-3,
+         eps: float = 1e-3,
+         formula: _FD_Formulas = "forward",
+         n_points: Literal[2, 3] = 2,
+         momentum: float = 0,
+         dampening: float = 0,
+         nesterov: bool = False,
+         weight_decay: float = 0,
+         decoupled=False,
+ 
+     ):
+         modules: list = [
+             _FDM(eps = eps, formula=formula, n_points=n_points),
+             SGD(momentum = momentum, dampening = dampening, weight_decay = weight_decay if not decoupled else 0, nesterov = nesterov),
+             LR(lr),
+ 
+         ]
+         if decoupled: modules.append(WeightDecay(weight_decay))
+         super().__init__(params, modules)
+ 
+ 
+ class FDMWrapper(Modular):
+     """Gradient approximation via finite difference. This wraps any other optimizer.
+     This also supports optimizers that perform multiple gradient evaluations per step, like LBFGS.
+ 
+     Example:
+     ```
+     lbfgs = torch.optim.LBFGS(params, lr = 1)
+     fdm = FDMWrapper(optimizer = lbfgs)
+     ```
+ 
+     This performs n+1 evaluations per step with `forward` and `backward` formulas,
+     and 2*n with `central` formula.
+ 
+     Args:
+         optimizer (torch.optim.Optimizer): optimizer that will perform optimization
+             using FDM-approximated gradients.
+         eps (float, optional): finite difference epsilon. Defaults to 1e-3.
+         formula (_FD_Formulas, optional): finite difference formula. Defaults to "forward".
+         n_points (Literal[2, 3], optional): number of points for finite difference formula, 2 or 3. Defaults to 2.
+     """
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         eps: float = 1e-3,
+         formula: _FD_Formulas = "forward",
+         n_points: Literal[2, 3] = 2,
+     ):
+         modules = [
+             _FDM(eps = eps, formula=formula, n_points=n_points, target = 'closure'),
+             WrapClosure(optimizer)
+         ]
+         # some optimizers have `eps` setting in param groups too.
+         # it should not be passed to FDM
+         super().__init__([p for g in optimizer.param_groups.copy() for p in g['params']], modules)
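Editor's note: a minimal usage sketch for the `FDM` optimizer above, not part of the diff. Since gradients are estimated from extra closure evaluations, the closure only has to return the loss; the `backward` flag in the assumed closure signature is never exercised here. Import path and closure convention are assumptions based on this diff.

```
import torch
from torchzero.optim.zeroth_order.fdm import FDM

model = torch.nn.Linear(4, 1)
xb, yb = torch.randn(32, 4), torch.randn(32, 1)

# central formula: roughly 2 * n loss evaluations per step, n = number of parameters
opt = FDM(model.parameters(), lr=1e-2, eps=1e-3, formula="central")

def closure(backward=True):
    # no .backward() needed: FDM only evaluates the loss at perturbed parameters
    return torch.nn.functional.mse_loss(model(xb), yb)

for _ in range(100):
    loss = opt.step(closure)
```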
torchzero/optim/zeroth_order/newton_fdm.py (new file)
@@ -0,0 +1,146 @@
+ from typing import Any, Literal
+ import torch
+ 
+ from ...modules import (LR, FallbackLinearSystemSolvers,
+                         LinearSystemSolvers, LineSearches, ClipNorm)
+ from ...modules import NewtonFDM as _NewtonFDM, get_line_search
+ from ...modules.experimental.subspace import Proj2Masks, ProjRandom, Subspace
+ from ..modular import Modular
+ 
+ 
+ class NewtonFDM(Modular):
+     """Newton method with gradient and hessian approximated via finite difference.
+ 
+     This performs approximately `4 * n^2 + 1` evaluations per step;
+     if `diag` is True, performs `n * 2 + 1` evaluations per step.
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         lr (float, optional): learning rate. Defaults to 1.
+         eps (float, optional): epsilon for finite difference.
+             Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
+         diag (bool, optional): whether to only approximate diagonal elements of the hessian.
+             This also ignores `solver` if True. Defaults to False.
+         solver (LinearSystemSolvers, optional):
+             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
+         fallback (FallbackLinearSystemSolvers, optional):
+             what to do if solver fails. Defaults to "safe_diag"
+             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
+         validate (bool, optional):
+             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
+             If not, undo the step and perform a gradient descent step.
+         tol (float, optional):
+             only has effect if `validate` is enabled.
+             If loss increased by `loss * tol`, perform gradient descent step.
+             Set this to 0 to guarantee that loss always decreases. Defaults to 2.
+         gd_lr (float, optional):
+             only has effect if `validate` is enabled.
+             Gradient descent step learning rate. Defaults to 1e-2.
+         line_search (LineSearches | None, optional): line search module, can be None. Defaults to 'brent'.
+     """
+     def __init__(
+         self,
+         params,
+         lr: float = 1,
+         eps: float = 1e-2,
+         diag=False,
+         solver: LinearSystemSolvers = "cholesky_lu",
+         fallback: FallbackLinearSystemSolvers = "safe_diag",
+         max_norm: float | None = None,
+         validate=False,
+         tol: float = 2,
+         gd_lr = 1e-2,
+         line_search: LineSearches | None = 'brent',
+     ):
+         modules: list[Any] = [
+             _NewtonFDM(eps = eps, diag = diag, solver=solver, fallback=fallback, validate=validate, tol=tol, gd_lr=gd_lr),
+         ]
+ 
+         if max_norm is not None:
+             modules.append(ClipNorm(max_norm))
+ 
+         modules.append(LR(lr))
+ 
+         if line_search is not None:
+             modules.append(get_line_search(line_search))
+ 
+         super().__init__(params, modules)
+ 
+ 
+ class RandomSubspaceNewtonFDM(Modular):
+     """This projects the parameters into a smaller dimensional subspace,
+     making approximating the hessian via finite difference feasible.
+ 
+     This performs approximately `4 * subspace_ndim^2 + 1` evaluations per step;
+     if `diag` is True, performs `subspace_ndim * 2 + 1` evaluations per step.
+ 
+     Args:
+         params: iterable of parameters to optimize or dicts defining parameter groups.
+         subspace_ndim (int, optional): number of random subspace dimensions. Defaults to 3.
+         lr (float, optional): learning rate. Defaults to 1.
+         eps (float, optional): epsilon for finite difference.
+             Note that with float32 this needs to be quite high to avoid numerical instability. Defaults to 1e-2.
+         diag (bool, optional): whether to only approximate diagonal elements of the hessian.
+         solver (LinearSystemSolvers, optional):
+             solver for Hx = g. Defaults to "cholesky_lu" (cholesky or LU if it fails).
+         fallback (FallbackLinearSystemSolvers, optional):
+             what to do if solver fails. Defaults to "safe_diag"
+             (takes nonzero diagonal elements, or fallbacks to gradient descent if all elements are 0).
+         validate (bool, optional):
+             validate if the step didn't increase the loss by `loss * tol` with an additional forward pass.
+             If not, undo the step and perform a gradient descent step.
+         tol (float, optional):
+             only has effect if `validate` is enabled.
+             If loss increased by `loss * tol`, perform gradient descent step.
+             Set this to 0 to guarantee that loss always decreases. Defaults to 2.
+         gd_lr (float, optional):
+             only has effect if `validate` is enabled.
+             Gradient descent step learning rate. Defaults to 1e-2.
+         line_search (LineSearches | None, optional): line search module, can be None. Defaults to 'brent'.
+         randomize_every (int, optional): generates new random projections every n steps. Defaults to 1.
+     """
+     def __init__(
+         self,
+         params,
+         subspace_ndim: int = 3,
+         lr: float = 1,
+         eps: float = 1e-2,
+         diag=False,
+         solver: LinearSystemSolvers = "cholesky_lu",
+         fallback: FallbackLinearSystemSolvers = "safe_diag",
+         max_norm: float | None = None,
+         validate=False,
+         tol: float = 2,
+         gd_lr = 1e-2,
+         line_search: LineSearches | None = 'brent',
+         randomize_every: int = 1,
+     ):
+         if subspace_ndim == 1: projections = [ProjRandom(1)]
+         else:
+             projections: list[Any] = [Proj2Masks(subspace_ndim//2)]
+             if subspace_ndim % 2 == 1: projections.append(ProjRandom(1))
+ 
+         modules: list[Any] = [
+             Subspace(
+                 modules = _NewtonFDM(
+                     eps = eps,
+                     diag = diag,
+                     solver=solver,
+                     fallback=fallback,
+                     validate=validate,
+                     tol=tol,
+                     gd_lr=gd_lr
+                 ),
+                 projections = projections,
+                 update_every=randomize_every),
+         ]
+         if max_norm is not None:
+             modules.append(ClipNorm(max_norm))
+ 
+         modules.append(LR(lr))
+ 
+         if line_search is not None:
+             modules.append(get_line_search(line_search))
+ 
+         super().__init__(params, modules)
+ 
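Editor's note: a small usage sketch for `NewtonFDM` above, not part of the diff. It illustrates the evaluation count from the docstring: with n = 2 parameters a full-hessian step costs roughly 4 * 2^2 + 1 = 17 loss evaluations. The closure convention and the acceptance of a plain tensor as the parameter are assumptions.

```
import torch
from torchzero.optim.zeroth_order.newton_fdm import NewtonFDM

# a two-dimensional quadratic with minimum at (3, -1)
x = torch.tensor([1.5, -0.5])
opt = NewtonFDM([x], lr=1.0, eps=1e-2)

def closure(backward=True):
    # gradient and hessian are built from finite differences of this value
    return (x[0] - 3) ** 2 + 10 * (x[1] + 1) ** 2

for _ in range(5):
    loss = opt.step(closure)
```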