torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -9,9 +9,17 @@ Additional functional variants are present in most module files, e.g. `adam_`, `
 """
 from collections.abc import Callable
 from typing import overload
+
 import torch
 
-from ..utils import NumberList, TensorList
+from ..utils import (
+    NumberList,
+    TensorList,
+    generic_finfo_eps,
+    generic_max,
+    generic_sum,
+    tofloat,
+)
 
 inf = float('inf')
 
@@ -87,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
         if p == 1: return tensors_.abs_()
         if p == 2: return tensors_.sqrt_()
         return tensors_.pow_(1/p)
-    else:
-        if p == 1: return tensors_.abs()
-        if p == 2: return tensors_.sqrt()
-        return tensors_.pow(1/p)
+
+    if p == 1: return tensors_.abs()
+    if p == 2: return tensors_.sqrt()
+    return tensors_.pow(1/p)
 
 
 def ema_(
@@ -207,13 +215,41 @@ def sqrt_centered_ema_sq_(
         ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
     )
 
-@overload
-def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
-@overload
-def safe_scaling_(tensors_: TensorList) -> TensorList: ...
-def safe_scaling_(tensors_: torch.Tensor | TensorList):
-    if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
-    else: scale = 1 / tensors_.abs().global_sum()
-    scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
-    return tensors_.mul_(scale)
+def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+    """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations,
+    this version is safer and makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_sum = generic_sum(tensors_abs)
+    tensors_max = generic_max(tensors_abs)
+
+    feps = generic_finfo_eps(tensors)
+    if eps is None: eps = feps
+    else: eps = max(eps, feps)
+
+    # scale should not make largest value smaller than epsilon
+    min = eps / tensors_max
+    if min >= 1: return 1.0
+
+    scale = 1 / tensors_sum
+    scale = scale.clip(min=min.item(), max=1)
+    return scale.item()
+
+
+def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+    """makes sure largest value isn't smaller than epsilon."""
+    tensors_abs = tensors.abs()
+    tensors_max = generic_max(tensors_abs)
+    if tensors_max < alpha: return 1.0
+
+    if tensors_max < 1: alpha = alpha / tensors_max
+    return tofloat(alpha)
+
+
+
+def safe_clip(x: torch.Tensor, min=None):
+    """makes sure absolute value of scalar tensor x is not smaller than min"""
+    assert x.numel() == 1, x.shape
+    if min is None: min = torch.finfo(x.dtype).tiny * 2
 
+    if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+    return x
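The functional.py hunk above replaces `safe_scaling_` with three step-size helpers. As a reading aid, here is a minimal standalone sketch of what `initial_step_size` computes for a single tensor, written in plain PyTorch; the real helper also accepts a `TensorList` through the `generic_*` utilities, and the names below are illustrative only.

```python
import torch

def initial_step_size_sketch(t: torch.Tensor, eps: float | None = None) -> float:
    # mirror of the helper added above, restricted to a single tensor
    t_abs = t.abs()
    feps = torch.finfo(t.dtype).eps
    eps = feps if eps is None else max(eps, feps)

    # the returned scale must not push the largest entry below epsilon
    lo = eps / t_abs.max()
    if lo >= 1:
        return 1.0

    scale = (1 / t_abs.sum()).clip(min=lo.item(), max=1)
    return scale.item()

direction = torch.randn(1000) * 1e4            # a large raw first direction
alpha = initial_step_size_sketch(direction)    # roughly 1 / sum(|direction|), clipped to [eps/max, 1]
first_step = direction * alpha                 # well-sized first trial step for a line search
```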
@@ -93,7 +93,7 @@ _FD_FUNCS = {
 class FDM(GradApproximator):
     """Approximate gradients via finite difference method.
 
-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -103,24 +103,23 @@ class FDM(GradApproximator):
        target (GradTarget, optional): what to set on var. Defaults to 'closure'.
 
    Examples:
-        plain FDM:
-
-        .. code-block:: python
-
-            fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
-
-        Any gradient-based method can use FDM-estimated gradients seamlessly.
-
-        .. code-block:: python
-
-            fdm_ncg = tz.Modular(
-                model.parameters(),
-                tz.m.FDM(),
-                # set hvp_method to "forward" so that it
-                # uses gradient difference instead of autograd
-                tz.m.NewtonCG(hvp_method="forward"),
-                tz.m.Backtracking()
-            )
+        plain FDM:
+
+        ```python
+        fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+        ```
+
+        Any gradient-based method can use FDM-estimated gradients.
+        ```python
+        fdm_ncg = tz.Modular(
+            model.parameters(),
+            tz.m.FDM(),
+            # set hvp_method to "forward" so that it
+            # uses gradient difference instead of autograd
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
    """
    def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
        defaults = dict(h=h, formula=formula)
@@ -139,7 +138,7 @@ class FDM(GradApproximator):
            h = settings['h']
            fd_fn = _FD_FUNCS[settings['formula']]

-            p_flat = p.view(-1); g_flat = g.view(-1)
+            p_flat = p.ravel(); g_flat = g.ravel()
            for i in range(len(p_flat)):
                loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
                g_flat[i] = d
@@ -15,7 +15,7 @@ class ForwardGradient(RandomizedFDM):
 
    This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
 
-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -67,7 +67,9 @@ class ForwardGradient(RandomizedFDM):
        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
-            if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator)
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, variance=1, generator=generator)
+
            else: prt = TensorList(prt)

            if jvp_method == 'autograd':
@@ -24,63 +24,59 @@ class GradApproximator(Module, ABC):
 
    Example:
 
-        Basic SPSA method implementation.
-
-        .. code-block:: python
-
-            class SPSA(GradApproximator):
-                def __init__(self, h=1e-3):
-                    defaults = dict(h=h)
-                    super().__init__(defaults)
-
-                @torch.no_grad
-                def approximate(self, closure, params, loss):
-                    perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
-
-                    # evaluate params + perturbation
-                    torch._foreach_add_(params, perturbation)
-                    loss_plus = closure(False)
-
-                    # evaluate params - perturbation
-                    torch._foreach_sub_(params, perturbation)
-                    torch._foreach_sub_(params, perturbation)
-                    loss_minus = closure(False)
-
-                    # restore original params
-                    torch._foreach_add_(params, perturbation)
-
-                    # calculate SPSA gradients
-                    spsa_grads = []
-                    for p, pert in zip(params, perturbation):
-                        settings = self.settings[p]
-                        h = settings['h']
-                        d = (loss_plus - loss_minus) / (2*(h**2))
-                        spsa_grads.append(pert * d)
-
-                    # returns tuple: (grads, loss, loss_approx)
-                    # loss must be with initial parameters
-                    # since we only evaluated loss with perturbed parameters
-                    # we only have loss_approx
-                    return spsa_grads, None, loss_plus
-
-    """
+        Basic SPSA method implementation.
+        ```python
+        class SPSA(GradApproximator):
+            def __init__(self, h=1e-3):
+                defaults = dict(h=h)
+                super().__init__(defaults)
+
+            @torch.no_grad
+            def approximate(self, closure, params, loss):
+                perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                # evaluate params + perturbation
+                torch._foreach_add_(params, perturbation)
+                loss_plus = closure(False)
+
+                # evaluate params - perturbation
+                torch._foreach_sub_(params, perturbation)
+                torch._foreach_sub_(params, perturbation)
+                loss_minus = closure(False)
+
+                # restore original params
+                torch._foreach_add_(params, perturbation)
+
+                # calculate SPSA gradients
+                spsa_grads = []
+                for p, pert in zip(params, perturbation):
+                    settings = self.settings[p]
+                    h = settings['h']
+                    d = (loss_plus - loss_minus) / (2*(h**2))
+                    spsa_grads.append(pert * d)
+
+                # returns tuple: (grads, loss, loss_approx)
+                # loss must be with initial parameters
+                # since we only evaluated loss with perturbed parameters
+                # we only have loss_approx
+                return spsa_grads, None, loss_plus
+        ```
+    """
    def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
        super().__init__(defaults)
        self._target: GradTarget = target

    @abstractmethod
-    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
-        """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
+    def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
+        """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""

-    def pre_step(self, var: Var) -> Var | None:
+    def pre_step(self, var: Var) -> None:
        """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
        evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-        return var

    @torch.no_grad
    def step(self, var):
-        ret = self.pre_step(var)
-        if isinstance(ret, Var): var = ret
+        self.pre_step(var)

        if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
        params, closure, loss = var.params, var.closure, var.loss
@@ -108,4 +104,4 @@ class GradApproximator(Module, ABC):
        else: raise ValueError(self._target)
        return var

-_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']
+_FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']
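The grad_approximator.py hunks above tighten the `GradApproximator` contract: `approximate` is now typed to return `torch.Tensor` losses, and `pre_step` returns `None` (`step` no longer inspects its return value). Below is a hedged sketch of a subclass written against that contract; the class name and the import path (inferred from the file layout in this diff) are illustrative, not part of the documented API.

```python
import torch
# module path follows the file layout shown in this diff; illustrative only
from torchzero.modules.grad_approximation.grad_approximator import GradApproximator

class ForwardDifferenceRFD(GradApproximator):
    """Illustrative subclass written against the 0.3.13 contract:
    pre_step() returns None and approximate() returns (grads, loss, loss_approx)."""

    def __init__(self, h: float = 1e-3):
        super().__init__(dict(h=h))

    def pre_step(self, var) -> None:
        # per-step state is stored on the module; nothing is returned anymore
        self._prt = [torch.randn_like(p) for p in var.params]

    @torch.no_grad
    def approximate(self, closure, params, loss):
        h = [self.settings[p]['h'] for p in params]
        step = torch._foreach_mul(self._prt, h)

        f_0 = loss if loss is not None else closure(False)  # loss at original params
        torch._foreach_add_(params, step)
        f_h = closure(False)
        torch._foreach_sub_(params, step)                   # restore original params

        # forward-difference random-direction estimate: prt * directional derivative
        grads = [prt * ((f_h - f_0) / hi) for prt, hi in zip(self._prt, h)]
        return grads, f_0, f_h
```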
@@ -115,26 +115,26 @@ def _rforward5(closure: Callable[..., float], params:TensorList, p_fn:Callable[[
    h = h**2 # because perturbation already multiplied by h
    return f_0, f_0, (-3*f_4 + 16*f_3 - 36*f_2 + 48*f_1 - 25*f_0) / (12 * h)

-# another central4
-def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
-    params += p_fn()
-    f_1 = closure(False)
+# # another central4
+# def _bgspsa4(closure: Callable[..., float], params:TensorList, p_fn:Callable[[], TensorList], h, f_0: float | None):
+#     params += p_fn()
+#     f_1 = closure(False)

-    params += p_fn() * 2
-    f_3 = closure(False)
+#     params += p_fn() * 2
+#     f_3 = closure(False)

-    params -= p_fn() * 4
-    f_m1 = closure(False)
+#     params -= p_fn() * 4
+#     f_m1 = closure(False)

-    params -= p_fn() * 2
-    f_m3 = closure(False)
+#     params -= p_fn() * 2
+#     f_m3 = closure(False)

-    params += p_fn() * 3
-    h = h**2 # because perturbation already multiplied by h
-    return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)
+#     params += p_fn() * 3
+#     h = h**2 # because perturbation already multiplied by h
+#     return f_0, f_1, (27*f_1 - f_m1 - f_3 + f_m3) / (48 * h)


-_RFD_FUNCS = {
+_RFD_FUNCS: dict[_FD_Formula, Callable] = {
    "forward": _rforward2,
    "forward2": _rforward2,
    "backward": _rbackward2,
@@ -147,14 +147,14 @@ _RFD_FUNCS = {
    "central4": _rcentral4,
    "forward4": _rforward4,
    "forward5": _rforward5,
-    "bspsa4": _bgspsa4,
+    # "bspsa4": _bgspsa4,
 }


 class RandomizedFDM(GradApproximator):
    """Gradient approximation via a randomized finite-difference method.

-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -171,94 +171,95 @@ class RandomizedFDM(GradApproximator):
        target (GradTarget, optional): what to set on var. Defaults to "closure".

    Examples:
-        #### Simultaneous perturbation stochastic approximation (SPSA) method
-
-        SPSA is randomized finite differnce with rademacher distribution and central formula.
-
-        .. code-block:: python
-
-            spsa = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
-                tz.m.LR(1e-2)
-            )
-
-        #### Random-direction stochastic approximation (RDSA) method
-
-        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
-
-        .. code-block:: python
-
-            rdsa = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
-                tz.m.LR(1e-2)
-            )
-
-        #### RandomizedFDM with momentum
-
-        Momentum might help by reducing the variance of the estimated gradients.
-
-        .. code-block:: python
-
-            momentum_spsa = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(),
-                tz.m.HeavyBall(0.9),
-                tz.m.LR(1e-3)
-            )
-
-        #### Gaussian smoothing method
-
-        GS uses many gaussian samples with possibly a larger finite difference step size.
-
-        .. code-block:: python
-
-            gs = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
-                tz.m.NewtonCG(hvp_method="forward"),
-                tz.m.Backtracking()
-            )
-
-        #### SPSA-NewtonCG
-
-        NewtonCG with hessian-vector product estimated via gradient difference
-        calls closure multiple times per step. If each closure call estimates gradients
-        with different perturbations, NewtonCG is unable to produce useful directions.
-
-        By setting pre_generate to True, perturbations are generated once before each step,
-        and each closure call estimates gradients using the same pre-generated perturbations.
-        This way closure-based algorithms are able to use gradients estimated in a consistent way.
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(n_samples=10),
-                tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
-                tz.m.Backtracking()
-            )
-
-        #### SPSA-BFGS
-
-        L-BFGS uses a memory of past parameter and gradient differences. If past gradients
-        were estimated with different perturbations, L-BFGS directions will be useless.
-
-        To alleviate this momentum can be added to random perturbations to make sure they only
-        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
-        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
-
-        Additionally we will reset BFGS memory every 100 steps to remove influence from old gradient estimates.
-
-        .. code-block:: python
-
-            opt = tz.Modular(
-                model.parameters(),
-                tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99),
-                tz.m.BFGS(reset_interval=100),
-                tz.m.Backtracking()
-            )
+        #### Simultaneous perturbation stochastic approximation (SPSA) method
+
+        SPSA is randomized finite differnce with rademacher distribution and central formula.
+        ```py
+        spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="rademacher"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### Random-direction stochastic approximation (RDSA) method
+
+        RDSA is randomized finite differnce with usually gaussian distribution and central formula.
+
+        ```
+        rdsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(formula="central", distribution="gaussian"),
+            tz.m.LR(1e-2)
+        )
+        ```
+
+        #### RandomizedFDM with momentum
+
+        Momentum might help by reducing the variance of the estimated gradients.
+
+        ```
+        momentum_spsa = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(),
+            tz.m.HeavyBall(0.9),
+            tz.m.LR(1e-3)
+        )
+        ```
+
+        #### Gaussian smoothing method
+
+        GS uses many gaussian samples with possibly a larger finite difference step size.
+
+        ```
+        gs = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=100, distribution="gaussian", formula="forward2", h=1e-1),
+            tz.m.NewtonCG(hvp_method="forward"),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-NewtonCG
+
+        NewtonCG with hessian-vector product estimated via gradient difference
+        calls closure multiple times per step. If each closure call estimates gradients
+        with different perturbations, NewtonCG is unable to produce useful directions.
+
+        By setting pre_generate to True, perturbations are generated once before each step,
+        and each closure call estimates gradients using the same pre-generated perturbations.
+        This way closure-based algorithms are able to use gradients estimated in a consistent way.
+
+        ```
+        opt = tz.Modular(
+            model.parameters(),
+            tz.m.RandomizedFDM(n_samples=10),
+            tz.m.NewtonCG(hvp_method="forward", pre_generate=True),
+            tz.m.Backtracking()
+        )
+        ```
+
+        #### SPSA-LBFGS
+
+        LBFGS uses a memory of past parameter and gradient differences. If past gradients
+        were estimated with different perturbations, LBFGS directions will be useless.
+
+        To alleviate this momentum can be added to random perturbations to make sure they only
+        change by a little bit, and the history stays relevant. The momentum is determined by the :code:`beta` parameter.
+        The disadvantage is that the subspace the algorithm is able to explore changes slowly.
+
+        Additionally we will reset SPSA and LBFGS memory every 100 steps to remove influence from old gradient estimates.
+
+        ```
+        opt = tz.Modular(
+            bench.parameters(),
+            tz.m.ResetEvery(
+                [tz.m.RandomizedFDM(n_samples=10, pre_generate=True, beta=0.99), tz.m.LBFGS()],
+                steps = 100,
+            ),
+            tz.m.Backtracking()
+        )
+        ```
    """
    PRE_MULTIPLY_BY_H = True
    def __init__(
@@ -280,6 +281,7 @@ class RandomizedFDM(GradApproximator):
        generator = self.global_state.get('generator', None) # avoid resetting generator
        self.global_state.clear()
        if generator is not None: self.global_state['generator'] = generator
+        for c in self.children.values(): c.reset()

    def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
        if 'generator' not in self.global_state:
@@ -290,15 +292,15 @@

    def pre_step(self, var):
        h, beta = self.get_settings(var.params, 'h', 'beta')
-        settings = self.settings[var.params[0]]
-        n_samples = settings['n_samples']
-        distribution = settings['distribution']
-        pre_generate = settings['pre_generate']
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']
+        pre_generate = self.defaults['pre_generate']

        if pre_generate:
            params = TensorList(var.params)
-            generator = self._get_generator(settings['seed'], var.params)
-            perturbations = [params.sample_like(distribution=distribution, generator=generator) for _ in range(n_samples)]
+            generator = self._get_generator(self.defaults['seed'], var.params)
+            perturbations = [params.sample_like(distribution=distribution, variance=1, generator=generator) for _ in range(n_samples)]

            if self.PRE_MULTIPLY_BY_H:
                torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
@@ -339,27 +341,44 @@ class RandomizedFDM(GradApproximator):
        grad = None
        for i in range(n_samples):
            prt = perturbations[i]
-            if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator).mul_(h)
+
+            if prt[0] is None:
+                prt = params.sample_like(distribution=distribution, generator=generator, variance=1).mul_(h)
+
            else: prt = TensorList(prt)

            loss, loss_approx, d = fd_fn(closure=closure, params=params, p_fn=lambda: prt, h=h, f_0=loss)
+            # here `d` is a numberlist of directional derivatives, due to per parameter `h` values.
+
+            # support for per-sample values which gives better estimate
+            if d[0].numel() > 1: d = d.map(torch.mean)
+
            if grad is None: grad = prt * d
            else: grad += prt * d

        params.set_(orig_params)
        assert grad is not None
        if n_samples > 1: grad.div_(n_samples)
+
+        # mean if got per-sample values
+        if loss is not None:
+            if loss.numel() > 1:
+                loss = loss.mean()
+
+        if loss_approx is not None:
+            if loss_approx.numel() > 1:
+                loss_approx = loss_approx.mean()
+
        return grad, loss, loss_approx

 class SPSA(RandomizedFDM):
    """
    Gradient approximation via Simultaneous perturbation stochastic approximation (SPSA) method.

-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

-
    Args:
        h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
        n_samples (int, optional): number of random gradient samples. Defaults to 1.
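The hunk above also adds per-sample loss support to `RandomizedFDM`: when the closure returns a vector of losses, the directional derivatives are reduced with `d.map(torch.mean)` and the returned `loss`/`loss_approx` are averaged. A standalone illustration of the idea with plain tensors follows; the names are hypothetical and this is not the module's API, just the arithmetic it now accepts.

```python
import torch

# A closure that returns one loss per data point (no reduction). Directional
# derivatives computed through it come out as vectors and are averaged, which
# is what the new `d.map(torch.mean)` branch does inside the module.
x, y = torch.randn(32, 4), torch.randn(32)
w = torch.zeros(4)
closure = lambda p: (x @ p - y) ** 2          # shape (32,), per-sample losses

h = 1e-3
prt = torch.empty_like(w).bernoulli_(0.5).mul_(2).sub_(1).mul_(h)   # rademacher * h

d = (closure(w + prt) - closure(w - prt)) / (2 * h * h)   # per-sample directional derivatives
if d.numel() > 1:
    d = d.mean()              # same reduction as the per-sample branch above
grad_estimate = prt * d       # SPSA-style gradient estimate
```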
@@ -380,7 +399,7 @@ class RDSA(RandomizedFDM):
    """
    Gradient approximation via Random-direction stochastic approximation (RDSA) method.

-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -417,7 +436,7 @@ class GaussianSmoothing(RandomizedFDM):
    """
    Gradient approximation via Gaussian smoothing method.

-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -453,7 +472,7 @@ class GaussianSmoothing(RandomizedFDM):
 class MeZO(GradApproximator):
    """Gradient approximation via memory-efficient zeroth order optimizer (MeZO) - https://arxiv.org/abs/2305.17333.

-    .. note::
+    Note:
        This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
        and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

@@ -476,15 +495,18 @@ class MeZO(GradApproximator):
        super().__init__(defaults, target=target)

    def _seeded_perturbation(self, params: list[torch.Tensor], distribution, seed, h):
-        return TensorList(params).sample_like(
-            distribution=distribution, generator=torch.Generator(params[0].device).manual_seed(seed)
-        ).mul_(h)
+        prt = TensorList(params).sample_like(
+            distribution=distribution,
+            variance=h,
+            generator=torch.Generator(params[0].device).manual_seed(seed)
+        )
+        return prt

    def pre_step(self, var):
        h = NumberList(self.settings[p]['h'] for p in var.params)
-        settings = self.settings[var.params[0]]
-        n_samples = settings['n_samples']
-        distribution = settings['distribution']
+
+        n_samples = self.defaults['n_samples']
+        distribution = self.defaults['distribution']

        step = var.current_step

@@ -1 +1 @@
-from .higher_order_newton import HigherOrderNewton
+from .higher_order_newton import HigherOrderNewton