torchzero 0.1.8__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. docs/source/conf.py +57 -0
  2. tests/test_identical.py +230 -0
  3. tests/test_module.py +50 -0
  4. tests/test_opts.py +884 -0
  5. tests/test_tensorlist.py +1787 -0
  6. tests/test_utils_optimizer.py +170 -0
  7. tests/test_vars.py +184 -0
  8. torchzero/__init__.py +4 -4
  9. torchzero/core/__init__.py +3 -13
  10. torchzero/core/module.py +629 -510
  11. torchzero/core/preconditioner.py +137 -0
  12. torchzero/core/transform.py +252 -0
  13. torchzero/modules/__init__.py +13 -21
  14. torchzero/modules/clipping/__init__.py +3 -0
  15. torchzero/modules/clipping/clipping.py +320 -0
  16. torchzero/modules/clipping/ema_clipping.py +135 -0
  17. torchzero/modules/clipping/growth_clipping.py +187 -0
  18. torchzero/modules/experimental/__init__.py +13 -18
  19. torchzero/modules/experimental/absoap.py +350 -0
  20. torchzero/modules/experimental/adadam.py +111 -0
  21. torchzero/modules/experimental/adamY.py +135 -0
  22. torchzero/modules/experimental/adasoap.py +282 -0
  23. torchzero/modules/experimental/algebraic_newton.py +145 -0
  24. torchzero/modules/experimental/curveball.py +89 -0
  25. torchzero/modules/experimental/dsoap.py +290 -0
  26. torchzero/modules/experimental/gradmin.py +85 -0
  27. torchzero/modules/experimental/reduce_outward_lr.py +35 -0
  28. torchzero/modules/experimental/spectral.py +286 -0
  29. torchzero/modules/experimental/subspace_preconditioners.py +128 -0
  30. torchzero/modules/experimental/tropical_newton.py +136 -0
  31. torchzero/modules/functional.py +209 -0
  32. torchzero/modules/grad_approximation/__init__.py +4 -0
  33. torchzero/modules/grad_approximation/fdm.py +120 -0
  34. torchzero/modules/grad_approximation/forward_gradient.py +81 -0
  35. torchzero/modules/grad_approximation/grad_approximator.py +66 -0
  36. torchzero/modules/grad_approximation/rfdm.py +259 -0
  37. torchzero/modules/line_search/__init__.py +5 -30
  38. torchzero/modules/line_search/backtracking.py +186 -0
  39. torchzero/modules/line_search/line_search.py +181 -0
  40. torchzero/modules/line_search/scipy.py +37 -0
  41. torchzero/modules/line_search/strong_wolfe.py +260 -0
  42. torchzero/modules/line_search/trust_region.py +61 -0
  43. torchzero/modules/lr/__init__.py +2 -0
  44. torchzero/modules/lr/lr.py +59 -0
  45. torchzero/modules/lr/step_size.py +97 -0
  46. torchzero/modules/momentum/__init__.py +14 -4
  47. torchzero/modules/momentum/averaging.py +78 -0
  48. torchzero/modules/momentum/cautious.py +181 -0
  49. torchzero/modules/momentum/ema.py +173 -0
  50. torchzero/modules/momentum/experimental.py +189 -0
  51. torchzero/modules/momentum/matrix_momentum.py +124 -0
  52. torchzero/modules/momentum/momentum.py +43 -106
  53. torchzero/modules/ops/__init__.py +103 -0
  54. torchzero/modules/ops/accumulate.py +65 -0
  55. torchzero/modules/ops/binary.py +240 -0
  56. torchzero/modules/ops/debug.py +25 -0
  57. torchzero/modules/ops/misc.py +419 -0
  58. torchzero/modules/ops/multi.py +137 -0
  59. torchzero/modules/ops/reduce.py +149 -0
  60. torchzero/modules/ops/split.py +75 -0
  61. torchzero/modules/ops/switch.py +68 -0
  62. torchzero/modules/ops/unary.py +115 -0
  63. torchzero/modules/ops/utility.py +112 -0
  64. torchzero/modules/optimizers/__init__.py +18 -10
  65. torchzero/modules/optimizers/adagrad.py +146 -49
  66. torchzero/modules/optimizers/adam.py +112 -118
  67. torchzero/modules/optimizers/lion.py +18 -11
  68. torchzero/modules/optimizers/muon.py +222 -0
  69. torchzero/modules/optimizers/orthograd.py +55 -0
  70. torchzero/modules/optimizers/rmsprop.py +103 -51
  71. torchzero/modules/optimizers/rprop.py +342 -99
  72. torchzero/modules/optimizers/shampoo.py +197 -0
  73. torchzero/modules/optimizers/soap.py +286 -0
  74. torchzero/modules/optimizers/sophia_h.py +129 -0
  75. torchzero/modules/projections/__init__.py +5 -0
  76. torchzero/modules/projections/dct.py +73 -0
  77. torchzero/modules/projections/fft.py +73 -0
  78. torchzero/modules/projections/galore.py +10 -0
  79. torchzero/modules/projections/projection.py +218 -0
  80. torchzero/modules/projections/structural.py +151 -0
  81. torchzero/modules/quasi_newton/__init__.py +7 -4
  82. torchzero/modules/quasi_newton/cg.py +218 -0
  83. torchzero/modules/quasi_newton/experimental/__init__.py +1 -0
  84. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +265 -0
  85. torchzero/modules/quasi_newton/lbfgs.py +228 -0
  86. torchzero/modules/quasi_newton/lsr1.py +170 -0
  87. torchzero/modules/quasi_newton/olbfgs.py +196 -0
  88. torchzero/modules/quasi_newton/quasi_newton.py +475 -0
  89. torchzero/modules/second_order/__init__.py +3 -4
  90. torchzero/modules/second_order/newton.py +142 -165
  91. torchzero/modules/second_order/newton_cg.py +84 -0
  92. torchzero/modules/second_order/nystrom.py +168 -0
  93. torchzero/modules/smoothing/__init__.py +2 -5
  94. torchzero/modules/smoothing/gaussian.py +164 -0
  95. torchzero/modules/smoothing/{laplacian_smoothing.py → laplacian.py} +115 -128
  96. torchzero/modules/weight_decay/__init__.py +1 -0
  97. torchzero/modules/weight_decay/weight_decay.py +52 -0
  98. torchzero/modules/wrappers/__init__.py +1 -0
  99. torchzero/modules/wrappers/optim_wrapper.py +91 -0
  100. torchzero/optim/__init__.py +2 -10
  101. torchzero/optim/utility/__init__.py +1 -0
  102. torchzero/optim/utility/split.py +45 -0
  103. torchzero/optim/wrappers/nevergrad.py +2 -28
  104. torchzero/optim/wrappers/nlopt.py +31 -16
  105. torchzero/optim/wrappers/scipy.py +79 -156
  106. torchzero/utils/__init__.py +27 -0
  107. torchzero/utils/compile.py +175 -37
  108. torchzero/utils/derivatives.py +513 -99
  109. torchzero/utils/linalg/__init__.py +5 -0
  110. torchzero/utils/linalg/matrix_funcs.py +87 -0
  111. torchzero/utils/linalg/orthogonalize.py +11 -0
  112. torchzero/utils/linalg/qr.py +71 -0
  113. torchzero/utils/linalg/solve.py +168 -0
  114. torchzero/utils/linalg/svd.py +20 -0
  115. torchzero/utils/numberlist.py +132 -0
  116. torchzero/utils/ops.py +10 -0
  117. torchzero/utils/optimizer.py +284 -0
  118. torchzero/utils/optuna_tools.py +40 -0
  119. torchzero/utils/params.py +149 -0
  120. torchzero/utils/python_tools.py +40 -25
  121. torchzero/utils/tensorlist.py +1081 -0
  122. torchzero/utils/torch_tools.py +48 -12
  123. torchzero-0.3.2.dist-info/METADATA +379 -0
  124. torchzero-0.3.2.dist-info/RECORD +128 -0
  125. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info}/WHEEL +1 -1
  126. {torchzero-0.1.8.dist-info → torchzero-0.3.2.dist-info/licenses}/LICENSE +0 -0
  127. torchzero-0.3.2.dist-info/top_level.txt +3 -0
  128. torchzero/core/tensorlist_optimizer.py +0 -219
  129. torchzero/modules/adaptive/__init__.py +0 -4
  130. torchzero/modules/adaptive/adaptive.py +0 -192
  131. torchzero/modules/experimental/experimental.py +0 -294
  132. torchzero/modules/experimental/quad_interp.py +0 -104
  133. torchzero/modules/experimental/subspace.py +0 -259
  134. torchzero/modules/gradient_approximation/__init__.py +0 -7
  135. torchzero/modules/gradient_approximation/_fd_formulas.py +0 -3
  136. torchzero/modules/gradient_approximation/base_approximator.py +0 -105
  137. torchzero/modules/gradient_approximation/fdm.py +0 -125
  138. torchzero/modules/gradient_approximation/forward_gradient.py +0 -163
  139. torchzero/modules/gradient_approximation/newton_fdm.py +0 -198
  140. torchzero/modules/gradient_approximation/rfdm.py +0 -125
  141. torchzero/modules/line_search/armijo.py +0 -56
  142. torchzero/modules/line_search/base_ls.py +0 -139
  143. torchzero/modules/line_search/directional_newton.py +0 -217
  144. torchzero/modules/line_search/grid_ls.py +0 -158
  145. torchzero/modules/line_search/scipy_minimize_scalar.py +0 -62
  146. torchzero/modules/meta/__init__.py +0 -12
  147. torchzero/modules/meta/alternate.py +0 -65
  148. torchzero/modules/meta/grafting.py +0 -195
  149. torchzero/modules/meta/optimizer_wrapper.py +0 -173
  150. torchzero/modules/meta/return_overrides.py +0 -46
  151. torchzero/modules/misc/__init__.py +0 -10
  152. torchzero/modules/misc/accumulate.py +0 -43
  153. torchzero/modules/misc/basic.py +0 -115
  154. torchzero/modules/misc/lr.py +0 -96
  155. torchzero/modules/misc/multistep.py +0 -51
  156. torchzero/modules/misc/on_increase.py +0 -53
  157. torchzero/modules/operations/__init__.py +0 -29
  158. torchzero/modules/operations/multi.py +0 -298
  159. torchzero/modules/operations/reduction.py +0 -134
  160. torchzero/modules/operations/singular.py +0 -113
  161. torchzero/modules/optimizers/sgd.py +0 -54
  162. torchzero/modules/orthogonalization/__init__.py +0 -2
  163. torchzero/modules/orthogonalization/newtonschulz.py +0 -159
  164. torchzero/modules/orthogonalization/svd.py +0 -86
  165. torchzero/modules/regularization/__init__.py +0 -22
  166. torchzero/modules/regularization/dropout.py +0 -34
  167. torchzero/modules/regularization/noise.py +0 -77
  168. torchzero/modules/regularization/normalization.py +0 -328
  169. torchzero/modules/regularization/ortho_grad.py +0 -78
  170. torchzero/modules/regularization/weight_decay.py +0 -92
  171. torchzero/modules/scheduling/__init__.py +0 -2
  172. torchzero/modules/scheduling/lr_schedulers.py +0 -131
  173. torchzero/modules/scheduling/step_size.py +0 -80
  174. torchzero/modules/smoothing/gaussian_smoothing.py +0 -90
  175. torchzero/modules/weight_averaging/__init__.py +0 -2
  176. torchzero/modules/weight_averaging/ema.py +0 -72
  177. torchzero/modules/weight_averaging/swa.py +0 -171
  178. torchzero/optim/experimental/__init__.py +0 -20
  179. torchzero/optim/experimental/experimental.py +0 -343
  180. torchzero/optim/experimental/ray_search.py +0 -83
  181. torchzero/optim/first_order/__init__.py +0 -18
  182. torchzero/optim/first_order/cautious.py +0 -158
  183. torchzero/optim/first_order/forward_gradient.py +0 -70
  184. torchzero/optim/first_order/optimizers.py +0 -570
  185. torchzero/optim/modular.py +0 -148
  186. torchzero/optim/quasi_newton/__init__.py +0 -1
  187. torchzero/optim/quasi_newton/directional_newton.py +0 -58
  188. torchzero/optim/second_order/__init__.py +0 -1
  189. torchzero/optim/second_order/newton.py +0 -94
  190. torchzero/optim/zeroth_order/__init__.py +0 -4
  191. torchzero/optim/zeroth_order/fdm.py +0 -87
  192. torchzero/optim/zeroth_order/newton_fdm.py +0 -146
  193. torchzero/optim/zeroth_order/rfdm.py +0 -217
  194. torchzero/optim/zeroth_order/rs.py +0 -85
  195. torchzero/random/__init__.py +0 -1
  196. torchzero/random/random.py +0 -46
  197. torchzero/tensorlist.py +0 -826
  198. torchzero-0.1.8.dist-info/METADATA +0 -130
  199. torchzero-0.1.8.dist-info/RECORD +0 -104
  200. torchzero-0.1.8.dist-info/top_level.txt +0 -1
torchzero/modules/grad_approximation/grad_approximator.py
@@ -0,0 +1,66 @@
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal
+
+ import torch
+
+ from ...core import Module, Vars
+
+ GradTarget = Literal['update', 'grad', 'closure']
+ _Scalar = torch.Tensor | float
+
+ class GradApproximator(Module, ABC):
+     """Base class for gradient approximations.
+     This is an abstract class; to use it, subclass it and override `approximate`.
+
+     Args:
+         defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
+         target (str, optional):
+             whether to set `vars.grad`, `vars.update` or `vars.closure`. Defaults to 'closure'.
+     """
+     def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
+         super().__init__(defaults)
+         self._target: GradTarget = target
+
+     @abstractmethod
+     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, vars: Vars) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
+         """Returns a tuple (grad, loss, loss_approx). Make sure this resets parameters to their original values!"""
+
+     def pre_step(self, vars: Vars) -> Vars | None:
+         """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
+         evaluate gradients at multiple points. This is useful, for example, to pre-generate new random perturbations."""
+         return vars
+
+     @torch.no_grad
+     def step(self, vars):
+         ret = self.pre_step(vars)
+         if isinstance(ret, Vars): vars = ret
+
+         if vars.closure is None: raise RuntimeError("Gradient approximation requires closure")
+         params, closure, loss = vars.params, vars.closure, vars.loss
+
+         if self._target == 'closure':
+
+             def approx_closure(backward=True):
+                 if backward:
+                     # set loss to None because closure might be evaluated at different points
+                     grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, vars=vars)
+                     for p, g in zip(params, grad): p.grad = g
+                     return l if l is not None else l_approx
+                 return closure(False)
+
+             vars.closure = approx_closure
+             return vars
+
+         # if vars.grad is not None:
+         #     warnings.warn('Using grad approximator when `vars.grad` is already set.')
+         grad, loss, loss_approx = self.approximate(closure=closure, params=params, loss=loss, vars=vars)
+         if loss_approx is not None: vars.loss_approx = loss_approx
+         if loss is not None: vars.loss = vars.loss_approx = loss
+         if self._target == 'grad': vars.grad = list(grad)
+         elif self._target == 'update': vars.update = list(grad)
+         else: raise ValueError(self._target)
+         return vars
+
+ _FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
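
As the docstring above says, a concrete approximator only needs to override `approximate`, return (grad, loss, loss_approx), and restore the parameters it perturbed. A minimal sketch of such a subclass follows; `TwoPointEstimator` is a hypothetical name, and the import path is an assumption (the package's `grad_approximation/__init__.py` gains 4 lines in this release, presumably re-exporting the base class).

    import torch
    from torchzero.modules.grad_approximation import GradApproximator  # assumed import path

    class TwoPointEstimator(GradApproximator):
        """Hypothetical example: central difference along a single random direction."""
        def __init__(self, h: float = 1e-3, target: str = 'closure'):
            super().__init__(dict(h=h), target=target)

        @torch.no_grad
        def approximate(self, closure, params, loss, vars):
            h = self.settings[params[0]]['h']
            u = [torch.randn_like(p) for p in params]           # random direction
            for p, v in zip(params, u): p.add_(v, alpha=h)      # move to x + h*u
            f_plus = closure(False)                             # evaluate loss without backward
            for p, v in zip(params, u): p.sub_(v, alpha=2 * h)  # move to x - h*u
            f_minus = closure(False)
            for p, v in zip(params, u): p.add_(v, alpha=h)      # restore original parameters
            d = (f_plus - f_minus) / (2 * h)                    # estimated directional derivative
            return [v * d for v in u], loss, f_plus             # (grad, loss, loss_approx)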
torchzero/modules/grad_approximation/rfdm.py
@@ -0,0 +1,259 @@
+ from collections.abc import Callable
+ from typing import Any
+ from functools import partial
+ import torch
+
+ from ...utils import TensorList, Distributions, NumberList, generic_eq
+ from .grad_approximator import GradApproximator, GradTarget, _FD_Formula
+
+
+ def _rforward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+     """p_fn is a function that returns the perturbation.
+     It may return a pre-generated one or generate one deterministically from a seed, as in MeZO.
+     The returned perturbation must already be multiplied by `h`."""
+     if v_0 is None: v_0 = closure(False)
+     params += p_fn()
+     v_plus = closure(False)
+     params -= p_fn()
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_0, (v_plus - v_0) / h # (loss, loss_approx, grad)
+
+ def _rbackward2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     params -= p_fn()
+     v_minus = closure(False)
+     params += p_fn()
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_0, (v_0 - v_minus) / h
+
+ def _rcentral2(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: Any):
+     params += p_fn()
+     v_plus = closure(False)
+
+     params -= p_fn() * 2
+     v_minus = closure(False)
+
+     params += p_fn()
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_plus, (v_plus - v_minus) / (2 * h)
+
+ def _rforward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+     params += p_fn()
+     v_plus1 = closure(False)
+
+     params += p_fn()
+     v_plus2 = closure(False)
+
+     params -= p_fn() * 2
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_0, (-3*v_0 + 4*v_plus1 - v_plus2) / (2 * h)
+
+ def _rbackward3(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+     if v_0 is None: v_0 = closure(False)
+
+     params -= p_fn()
+     v_minus1 = closure(False)
+
+     params -= p_fn()
+     v_minus2 = closure(False)
+
+     params += p_fn() * 2
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_0, (v_minus2 - 4*v_minus1 + 3*v_0) / (2 * h)
+
+ def _rcentral4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, v_0: float | None):
+     params += p_fn()
+     v_plus1 = closure(False)
+
+     params += p_fn()
+     v_plus2 = closure(False)
+
+     params -= p_fn() * 3
+     v_minus1 = closure(False)
+
+     params -= p_fn()
+     v_minus2 = closure(False)
+
+     params += p_fn() * 2
+     h = h**2 # because perturbation already multiplied by h
+     return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)
+
+ _RFD_FUNCS = {
+     "forward2": _rforward2,
+     "backward2": _rbackward2,
+     "central2": _rcentral2,
+     "forward3": _rforward3,
+     "backward3": _rbackward3,
+     "central4": _rcentral4,
+ }
+
+
+ class RandomizedFDM(GradApproximator):
+     PRE_MULTIPLY_BY_H = True
+     def __init__(
+         self,
+         h: float = 1e-3,
+         n_samples: int = 1,
+         formula: _FD_Formula = "central2",
+         distribution: Distributions = "rademacher",
+         beta: float = 0,
+         pre_generate = True,
+         target: GradTarget = "closure",
+         seed: int | None | torch.Generator = None,
+     ):
+         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution, beta=beta, pre_generate=pre_generate, seed=seed)
+         super().__init__(defaults, target=target)
+
+     def reset(self):
+         self.state.clear()
+         generator = self.global_state.get('generator', None) # avoid resetting generator
+         self.global_state.clear()
+         if generator is not None: self.global_state['generator'] = generator
+
+     def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
+         if 'generator' not in self.global_state:
+             if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
+             elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
+             else: self.global_state['generator'] = None
+         return self.global_state['generator']
+
+     def pre_step(self, vars):
+         h, beta = self.get_settings('h', 'beta', params=vars.params)
+         settings = self.settings[vars.params[0]]
+         n_samples = settings['n_samples']
+         distribution = settings['distribution']
+         pre_generate = settings['pre_generate']
+
+         if pre_generate:
+             params = TensorList(vars.params)
+             generator = self._get_generator(settings['seed'], vars.params)
+             perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+
+             if self.PRE_MULTIPLY_BY_H:
+                 torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+             if all(i==0 for i in beta):
+                 # just use pre-generated perturbations
+                 for param, prt in zip(params, zip(*perturbations)):
+                     self.state[param]['perturbations'] = prt
+
+             else:
+                 # lerp old and new perturbations. This makes the subspace change gradually
+                 # which in theory might improve algorithms with history
+                 for i,p in enumerate(params):
+                     state = self.state[p]
+                     if 'perturbations' not in state: state['perturbations'] = [p[i] for p in perturbations]
+
+                 cur = [self.state[p]['perturbations'][:n_samples] for p in params]
+                 cur_flat = [p for l in cur for p in l]
+                 new_flat = [p for l in zip(*perturbations) for p in l]
+                 betas = [1-v for b in beta for v in [b]*n_samples]
+                 torch._foreach_lerp_(cur_flat, new_flat, betas)
+
+     @torch.no_grad
+     def approximate(self, closure, params, loss, vars):
+         params = TensorList(params)
+         loss_approx = None
+
+         h = self.get_settings('h', params=vars.params, cls=NumberList)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         fd_fn = _RFD_FUNCS[settings['formula']]
+         default = [None]*n_samples
+         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+         distribution = settings['distribution']
+         generator = self._get_generator(settings['seed'], params)
+
+         grad = None
+         for i in range(n_samples):
+             prt = perturbations[i]
+             if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
+             else: prt = TensorList(prt)
+
+             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, v_0=loss)
+             if grad is None: grad = prt * d
+             else: grad += prt * d
+
+         assert grad is not None
+         if n_samples > 1: grad.div_(n_samples)
+         return grad, loss, loss_approx
+
+ SPSA = RandomizedFDM
+
+ class RDSA(RandomizedFDM):
+     def __init__(
+         self,
+         h: float = 1e-3,
+         n_samples: int = 1,
+         formula: _FD_Formula = "central2",
+         distribution: Distributions = "gaussian",
+         beta: float = 0,
+         pre_generate = True,
+         target: GradTarget = "closure",
+         seed: int | None | torch.Generator = None,
+     ):
+         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+
+ class GaussianSmoothing(RandomizedFDM):
+     def __init__(
+         self,
+         h: float = 1e-2,
+         n_samples: int = 100,
+         formula: _FD_Formula = "central2",
+         distribution: Distributions = "gaussian",
+         beta: float = 0,
+         pre_generate = True,
+         target: GradTarget = "closure",
+         seed: int | None | torch.Generator = None,
+     ):
+         super().__init__(h=h, n_samples=n_samples,formula=formula,distribution=distribution,beta=beta,pre_generate=pre_generate,target=target,seed=seed)
+
+ class MeZO(GradApproximator):
+     def __init__(self, h: float=1e-3, n_samples: int = 1, formula: _FD_Formula = 'central2',
+                  distribution: Distributions = 'rademacher', target: GradTarget = 'closure'):
+         defaults = dict(h=h, formula=formula, n_samples=n_samples, distribution=distribution)
+         super().__init__(defaults, target=target)
+
+     def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
+         return TensorList(params).sample_like(
+             distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
+         ).mul_(h)
+
+     def pre_step(self, vars):
+         h = self.get_settings('h', params=vars.params)
+         settings = self.settings[vars.params[0]]
+         n_samples = settings['n_samples']
+         distribution = settings['distribution']
+
+         step = vars.current_step
+
+         # create functions that generate a deterministic perturbation from seed based on current step
+         prt_fns = []
+         for i in range(n_samples):
+
+             prt_fn = partial(self._seeded_perturbation, params=vars.params, distribution=distribution, seed=1_000_000*step + i, h=h)
+             prt_fns.append(prt_fn)
+
+         self.global_state['prt_fns'] = prt_fns
+
+     @torch.no_grad
+     def approximate(self, closure, params, loss, vars):
+         params = TensorList(params)
+         loss_approx = None
+
+         h = self.get_settings('h', params=vars.params, cls=NumberList)
+         settings = self.settings[params[0]]
+         n_samples = settings['n_samples']
+         fd_fn = _RFD_FUNCS[settings['formula']]
+         prt_fns = self.global_state['prt_fns']
+
+         grad = None
+         for i in range(n_samples):
+             loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=prt_fns[i], h=h, v_0=loss)
+             if grad is None: grad = prt_fns[i]().mul_(d)
+             else: grad += prt_fns[i]().mul_(d)
+
+         assert grad is not None
+         if n_samples > 1: grad.div_(n_samples)
+         return grad, loss, loss_approx
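
A reading of the code above, for clarity (not text from the package): the repeated `h = h**2  # because perturbation already multiplied by h` comments follow from `PRE_MULTIPLY_BY_H`, which stores the perturbation as h·u rather than u. Taking `central2` as the example, the function difference is divided by 2h² and then multiplied by the stored h·u in `approximate`, which recovers the usual randomized two-point estimate along the direction u:

$$\hat g \;=\; (hu)\,\frac{f(x+hu)-f(x-hu)}{2h^{2}} \;=\; u\,\frac{f(x+hu)-f(x-hu)}{2h}.$$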
torchzero/modules/line_search/__init__.py
@@ -1,30 +1,5 @@
- r"""
- Line searches.
- """
-
- from typing import Literal
-
- from ...core import OptimizerModule
- from ..regularization import Normalize
- from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS,
-                       MultiplicativeLS)
- # from .quad_interp import QuadraticInterpolation2Point
- from .directional_newton import DirectionalNewton3Points, DirectionalNewton
- from .scipy_minimize_scalar import ScipyMinimizeScalarLS
- from .armijo import ArmijoLS
-
- LineSearches = Literal['backtracking', 'brent', 'brent-exact', 'brent-norm', 'multiplicative', 'newton', 'newton3', 'armijo'] | OptimizerModule
-
- def get_line_search(name:str | OptimizerModule) -> OptimizerModule | list[OptimizerModule]:
-     if isinstance(name, str):
-         name = name.strip().lower()
-         if name == 'backtracking': return BacktrackingLS()
-         if name == 'multiplicative': return MultiplicativeLS()
-         if name == 'brent': return ScipyMinimizeScalarLS(maxiter=8)
-         if name == 'brent-exact': return ScipyMinimizeScalarLS()
-         if name == 'brent-norm': return [Normalize(), ScipyMinimizeScalarLS(maxiter=16)]
-         if name == 'newton': return DirectionalNewton(1)
-         if name == 'newton3': return DirectionalNewton3Points(1)
-         if name == 'armijo': return ArmijoLS(1)
-         raise ValueError(f"Unknown line search method: {name}")
-     return name
+ from .line_search import LineSearch, GridLineSearch
+ from .backtracking import backtracking_line_search, Backtracking, AdaptiveBacktracking
+ from .strong_wolfe import StrongWolfe
+ from .scipy import ScipyMinimizeScalar
+ from .trust_region import TrustRegion
torchzero/modules/line_search/backtracking.py
@@ -0,0 +1,186 @@
+ import math
+ from collections.abc import Callable
+ from operator import itemgetter
+
+ import torch
+
+ from .line_search import LineSearch
+
+
+ def backtracking_line_search(
+     f: Callable[[float], float],
+     g_0: float | torch.Tensor,
+     init: float = 1.0,
+     beta: float = 0.5,
+     c: float = 1e-4,
+     maxiter: int = 10,
+     a_min: float | None = None,
+     try_negative: bool = False,
+ ) -> float | None:
+     """Backtracking line search with the Armijo sufficient decrease condition.
+
+     Args:
+         f: evaluates the objective at a given step size along the descent direction.
+         g_0: directional derivative along the descent direction.
+         init: initial step size.
+         beta: the factor by which to decrease the step size in each iteration.
+         c: the constant for the Armijo sufficient decrease condition.
+         maxiter: maximum number of backtracking iterations (default: 10).
+         a_min: minimum allowable step size; returned if backtracking goes below it.
+
+     Returns:
+         step size, or None if no acceptable step size was found.
+     """
+
+     a = init
+     f_x = f(0)
+
+     for iteration in range(maxiter):
+         f_a = f(a)
+
+         if f_a <= f_x + c * a * min(g_0, 0): # pyright: ignore[reportArgumentType]
+             # found an acceptable alpha
+             return a
+
+         # decrease alpha
+         a *= beta
+
+         # alpha too small
+         if a_min is not None and a < a_min:
+             return a_min
+
+     # fail
+     if try_negative:
+         def inv_objective(alpha): return f(-alpha)
+
+         v = backtracking_line_search(
+             inv_objective,
+             g_0=-g_0,
+             beta=beta,
+             c=c,
+             maxiter=maxiter,
+             a_min=a_min,
+             try_negative=False,
+         )
+         if v is not None: return -v
+
+     return None
+
+ class Backtracking(LineSearch):
+     def __init__(
+         self,
+         init: float = 1.0,
+         beta: float = 0.5,
+         c: float = 1e-4,
+         maxiter: int = 10,
+         min_alpha: float | None = None,
+         adaptive=True,
+         try_negative: bool = False,
+     ):
+         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,adaptive=adaptive, try_negative=try_negative)
+         super().__init__(defaults=defaults)
+         self.global_state['beta_scale'] = 1.0
+
+     def reset(self):
+         super().reset()
+         self.global_state['beta_scale'] = 1.0
+
+     @torch.no_grad
+     def search(self, update, vars):
+         init, beta, c, maxiter, min_alpha, adaptive, try_negative = itemgetter(
+             'init', 'beta', 'c', 'maxiter', 'min_alpha', 'adaptive', 'try_negative')(self.settings[vars.params[0]])
+
+         objective = self.make_objective(vars=vars)
+
+         # directional derivative
+         d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), vars.get_update()))
+
+         # scale beta (beta is multiplicative and i think may be better than scaling initial step size)
+         if adaptive: beta = beta * self.global_state['beta_scale']
+
+         step_size = backtracking_line_search(objective, d, init=init,beta=beta,
+                                              c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+
+         # found an alpha that reduces loss
+         if step_size is not None:
+             self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+             return step_size
+
+         # on fail reduce beta scale value
+         self.global_state['beta_scale'] /= 1.5
+         return 0
+
+ def _lerp(start,end,weight):
+     return start + weight * (end - start)
+
+ class AdaptiveBacktracking(LineSearch):
+     def __init__(
+         self,
+         init: float = 1.0,
+         beta: float = 0.5,
+         c: float = 1e-4,
+         maxiter: int = 20,
+         min_alpha: float | None = None,
+         target_iters = 1,
+         nplus = 2.0,
+         scale_beta = 0.0,
+         try_negative: bool = False,
+     ):
+         defaults=dict(init=init,beta=beta,c=c,maxiter=maxiter,min_alpha=min_alpha,target_iters=target_iters,nplus=nplus,scale_beta=scale_beta, try_negative=try_negative)
+         super().__init__(defaults=defaults)
+
+         self.global_state['beta_scale'] = 1.0
+         self.global_state['initial_scale'] = 1.0
+
+     def reset(self):
+         super().reset()
+         self.global_state['beta_scale'] = 1.0
+         self.global_state['initial_scale'] = 1.0
+
+     @torch.no_grad
+     def search(self, update, vars):
+         init, beta, c, maxiter, min_alpha, target_iters, nplus, scale_beta, try_negative=itemgetter(
+             'init','beta','c','maxiter','min_alpha','target_iters','nplus','scale_beta', 'try_negative')(self.settings[vars.params[0]])
+
+         objective = self.make_objective(vars=vars)
+
+         # directional derivative (0 if c = 0 because it is not needed)
+         if c == 0: d = 0
+         else: d = -sum(t.sum() for t in torch._foreach_mul(vars.get_grad(), update))
+
+         # scale beta
+         beta = beta * self.global_state['beta_scale']
+
+         # scale step size so that decrease is expected at target_iters
+         init = init * self.global_state['initial_scale']
+
+         step_size = backtracking_line_search(objective, d, init=init, beta=beta,
+                                              c=c,maxiter=maxiter,a_min=min_alpha, try_negative=try_negative)
+
+         # found an alpha that reduces loss
+         if step_size is not None:
+
+             # update initial_scale
+             # initial step size satisfied conditions, increase initial_scale by nplus
+             if step_size == init and target_iters > 0:
+                 self.global_state['initial_scale'] *= nplus ** target_iters
+                 self.global_state['initial_scale'] = min(self.global_state['initial_scale'], 1e32) # avoid overflow error
+
+             else:
+                 # otherwise make initial_scale such that target_iters iterations will satisfy armijo
+                 init_target = step_size
+                 for _ in range(target_iters):
+                     init_target = step_size / beta
+
+                 self.global_state['initial_scale'] = _lerp(
+                     self.global_state['initial_scale'], init_target / init, 1-scale_beta
+                 )
+
+             # revert beta_scale
+             self.global_state['beta_scale'] = min(1.0, self.global_state['beta_scale'] * math.sqrt(1.5))
+
+             return step_size
+
+         # on fail reduce beta scale value
+         self.global_state['beta_scale'] /= 1.5
+         return 0
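
Since `backtracking_line_search` above is a plain helper function (re-exported by the new `line_search/__init__.py` shown earlier), it can be exercised on its own. The Armijo test it applies is f(a) <= f(0) + c*a*min(g_0, 0). A small sketch on a one-dimensional quadratic; the top-level import path is an assumption:

    from torchzero.modules.line_search import backtracking_line_search  # assumed import path

    # minimize f(x) = x**2 from x0 = 3 along the steepest-descent direction d = -f'(x0) = -6
    def phi(a): return (3.0 - 6.0 * a) ** 2   # objective as a function of the step size a
    g_0 = 6.0 * (-6.0)                        # directional derivative f'(x0) * d = -36 (negative: descent)

    a = backtracking_line_search(phi, g_0, init=1.0, beta=0.5, c=1e-4, maxiter=10)
    print(a)  # 0.5: a = 1.0 fails the Armijo test (phi(1.0) = 9 = phi(0)), a = 0.5 gives phi = 0 and passes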