torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -0,0 +1,93 @@
+ from collections.abc import Callable
+ from typing import Any
+ from functools import partial
+ import torch
+
+ from ...utils import TensorList, NumberList
+ from ..grad_approximation.grad_approximator import GradApproximator, GradTarget
+
+ class SPSA1(GradApproximator):
+     """One-measurement variant of SPSA. Unlike standard two-measurement SPSA, the estimated
+     gradient often won't be a descent direction; however, its expectation is biased towards
+     the descent direction. Therefore this variant of SPSA is only recommended for a specific
+     class of problems where the objective function changes on each evaluation,
+     for example feedback control problems.
+
+     Args:
+         h (float, optional):
+             finite difference step size, recommended to be set to the same value as the learning rate. Defaults to 1e-3.
+         n_samples (int, optional): number of random samples. Defaults to 1.
+         eps (float, optional): measurement noise estimate. Defaults to 1e-8.
+         seed (int | None | torch.Generator, optional): random seed. Defaults to None.
+         target (GradTarget, optional): what to set on closure. Defaults to "closure".
+
+     Reference:
+         [Spall, James C. "A One-measurement Form of Simultaneous Perturbation Stochastic Approximation"](https://www.jhuapl.edu/spsa/PDF-SPSA/automatica97_one_measSPSA.pdf)
+     """
+
+     def __init__(
+         self,
+         h: float = 1e-3,
+         n_samples: int = 1,
+         eps: float = 1e-8, # measurement noise
+         pre_generate = False,
+         seed: int | None | torch.Generator = None,
+         target: GradTarget = "closure",
+     ):
+         defaults = dict(h=h, eps=eps, n_samples=n_samples, pre_generate=pre_generate, seed=seed)
+         super().__init__(defaults, target=target)
+
+
+     def pre_step(self, var):
+
+         if self.defaults['pre_generate']:
+
+             params = TensorList(var.params)
+             generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+             n_samples = self.defaults['n_samples']
+             h = self.get_settings(var.params, 'h')
+
+             perturbations = [params.sample_like(distribution='rademacher', generator=generator) for _ in range(n_samples)]
+             torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in h for v in [vv]*n_samples])
+
+             for param, prt in zip(params, zip(*perturbations)):
+                 self.state[param]['perturbations'] = prt
+
+     @torch.no_grad
+     def approximate(self, closure, params, loss):
+         generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+         params = TensorList(params)
+         orig_params = params.clone() # store to avoid small changes due to float imprecision
+         loss_approx = None
+
+         h, eps = self.get_settings(params, "h", "eps", cls=NumberList)
+         n_samples = self.defaults['n_samples']
+
+         default = [None]*n_samples
+         # perturbations are pre-multiplied by h
+         perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
+
+         grad = None
+         for i in range(n_samples):
+             prt = perturbations[i]
+
+             if prt[0] is None:
+                 prt = params.sample_like('rademacher', generator=generator).mul_(h)
+
+             else: prt = TensorList(prt)
+
+             params += prt
+             L = closure(False)
+             params.copy_(orig_params)
+
+             sample = prt * ((L + eps) / h)
+             if grad is None: grad = sample
+             else: grad += sample
+
+         assert grad is not None
+         if n_samples > 1: grad.div_(n_samples)
+
+         # mean if got per-sample values
+         return grad, loss, loss_approx
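
Reviewer note: for readers skimming the diff, here is a minimal single-tensor sketch of the estimate that one iteration of the loop in `approximate` produces. The real module operates on a `TensorList` across all parameters, supports pre-generated perturbations, and averages over `n_samples`; `f`, `x` and `spsa1_sample` below are illustrative names, not part of the torchzero API.

```python
import torch

def spsa1_sample(f, x: torch.Tensor, h: float = 1e-3, eps: float = 1e-8) -> torch.Tensor:
    # Rademacher perturbation, pre-multiplied by h as in pre_step / approximate above
    delta = torch.randint(0, 2, x.shape, device=x.device).to(x.dtype).mul_(2).sub_(1)
    prt = delta * h
    # single measurement at the perturbed point; the unperturbed loss is never evaluated
    L = f(x + prt)
    # mirrors `sample = prt * ((L + eps) / h)`, i.e. delta * (L + eps) for Rademacher delta
    return prt * ((L + eps) / h)

f = lambda x: (x ** 2).sum()          # toy objective
x = torch.tensor([1.0, -2.0, 3.0])
g_hat = spsa1_sample(f, x)            # one-measurement gradient estimate
```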
@@ -5,7 +5,7 @@ import torch
 
  from ...core import Chainable
  from ...utils import vec_to_tensors, TensorList
- from ..optimizers.shampoo import _merge_small_dims
+ from ..adaptive.shampoo import _merge_small_dims
  from ..projections import ProjectionBase
 
 
@@ -9,9 +9,17 @@ Additional functional variants are present in most module files, e.g. `adam_`, `
  """
  from collections.abc import Callable
  from typing import overload
+
  import torch
 
- from ..utils import NumberList, TensorList
+ from ..utils import (
+     NumberList,
+     TensorList,
+     generic_finfo_eps,
+     generic_max,
+     generic_sum,
+     tofloat,
+ )
 
  inf = float('inf')
 
@@ -87,10 +95,10 @@ def root(tensors_:TensorList, p:float, inplace: bool):
          if p == 1: return tensors_.abs_()
          if p == 2: return tensors_.sqrt_()
          return tensors_.pow_(1/p)
-     else:
-         if p == 1: return tensors_.abs()
-         if p == 2: return tensors_.sqrt()
-         return tensors_.pow(1/p)
+
+     if p == 1: return tensors_.abs()
+     if p == 2: return tensors_.sqrt()
+     return tensors_.pow(1/p)
 
 
  def ema_(
@@ -207,13 +215,41 @@ def sqrt_centered_ema_sq_(
          ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
      )
 
- @overload
- def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
- @overload
- def safe_scaling_(tensors_: TensorList) -> TensorList: ...
- def safe_scaling_(tensors_: torch.Tensor | TensorList):
-     if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
-     else: scale = 1 / tensors_.abs().global_sum()
-     scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
-     return tensors_.mul_(scale)
+ def initial_step_size(tensors: torch.Tensor | TensorList, eps=None) -> float:
+     """initial scaling taken from pytorch L-BFGS to avoid requiring a lot of line search iterations;
+     this version is safer and makes sure the largest value isn't smaller than epsilon."""
+     tensors_abs = tensors.abs()
+     tensors_sum = generic_sum(tensors_abs)
+     tensors_max = generic_max(tensors_abs)
+
+     feps = generic_finfo_eps(tensors)
+     if eps is None: eps = feps
+     else: eps = max(eps, feps)
+
+     # scale should not make largest value smaller than epsilon
+     min = eps / tensors_max
+     if min >= 1: return 1.0
+
+     scale = 1 / tensors_sum
+     scale = scale.clip(min=min.item(), max=1)
+     return scale.item()
+
+
+ def epsilon_step_size(tensors: torch.Tensor | TensorList, alpha=1e-7) -> float:
+     """makes sure the largest value isn't smaller than epsilon."""
+     tensors_abs = tensors.abs()
+     tensors_max = generic_max(tensors_abs)
+     if tensors_max < alpha: return 1.0
+
+     if tensors_max < 1: alpha = alpha / tensors_max
+     return tofloat(alpha)
+
+
+
+ def safe_clip(x: torch.Tensor, min=None):
+     """makes sure the absolute value of scalar tensor x is not smaller than min"""
+     assert x.numel() == 1, x.shape
+     if min is None: min = torch.finfo(x.dtype).tiny * 2
 
+     if x.abs() < min: return x.new_full(x.size(), min).copysign(x)
+     return x
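
Reviewer note: `initial_step_size` replaces the removed `safe_scaling_` and returns a float instead of scaling the tensors in place. Below is a single-tensor sketch of the rule, under the assumption that `generic_sum` / `generic_max` / `generic_finfo_eps` reduce to plain tensor ops for one tensor; `initial_scale` is an illustrative name, not part of the package.

```python
import torch

def initial_scale(g: torch.Tensor, eps: float | None = None) -> float:
    # scale = 1 / sum(|g|), clipped to [eps / max(|g|), 1] so that the largest
    # scaled entry is never pushed below machine epsilon
    feps = torch.finfo(g.dtype).eps
    eps = feps if eps is None else max(eps, feps)
    g_abs = g.abs()
    lo = eps / g_abs.max()
    if lo >= 1:
        return 1.0
    return (1 / g_abs.sum()).clip(min=lo.item(), max=1).item()

g = torch.tensor([0.5, -2.0, 1.5])
print(initial_scale(g))   # 1 / 4.0 = 0.25, well inside the clip range
```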
@@ -1,4 +1,4 @@
  from .grad_approximator import GradApproximator, GradTarget
  from .fdm import FDM
  from .rfdm import RandomizedFDM, MeZO, SPSA, RDSA, GaussianSmoothing
- from .forward_gradient import ForwardGradient
+ from .forward_gradient import ForwardGradient
@@ -93,7 +93,7 @@ _FD_FUNCS = {
  class FDM(GradApproximator):
      """Approximate gradients via finite difference method.
 
-     .. note::
+     Note:
          This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
          and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -103,24 +103,23 @@ class FDM(GradApproximator):
          target (GradTarget, optional): what to set on var. Defaults to 'closure'.
 
      Examples:
-         plain FDM:
-
-         .. code-block:: python
-
-             fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
-
-         Any gradient-based method can use FDM-estimated gradients seamlessly.
-
-         .. code-block:: python
-
-             fdm_ncg = tz.Modular(
-                 model.parameters(),
-                 tz.m.FDM(),
-                 # set hvp_method to "forward" so that it
-                 # uses gradient difference instead of autograd
-                 tz.m.NewtonCG(hvp_method="forward"),
-                 tz.m.Backtracking()
-             )
+         plain FDM:
+
+         ```python
+         fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+         ```
+
+         Any gradient-based method can use FDM-estimated gradients.
+         ```python
+         fdm_ncg = tz.Modular(
+             model.parameters(),
+             tz.m.FDM(),
+             # set hvp_method to "forward" so that it
+             # uses gradient difference instead of autograd
+             tz.m.NewtonCG(hvp_method="forward"),
+             tz.m.Backtracking()
+         )
+         ```
      """
      def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
          defaults = dict(h=h, formula=formula)
@@ -139,7 +138,7 @@ class FDM(GradApproximator):
          h = settings['h']
          fd_fn = _FD_FUNCS[settings['formula']]
 
-         p_flat = p.view(-1); g_flat = g.view(-1)
+         p_flat = p.ravel(); g_flat = g.ravel()
          for i in range(len(p_flat)):
              loss, loss_approx, d = fd_fn(closure=closure, param=p_flat, idx=i, h=h, v_0=loss)
              g_flat[i] = d
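
Reviewer note: the `.view(-1)` → `.ravel()` change sits inside FDM's per-coordinate loop. For reference, a plain-tensor sketch of that loop with the default central-difference formula (illustrative only; the real module dispatches through `_FD_FUNCS` and the closure machinery):

```python
import torch

def fdm_central(f, x: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
    # central difference per coordinate: df/dx_i ≈ (f(x + h e_i) - f(x - h e_i)) / (2h)
    g = torch.empty_like(x)
    x_flat, g_flat = x.ravel(), g.ravel()   # ravel() returns writable views here
    for i in range(len(x_flat)):
        orig = x_flat[i].item()
        x_flat[i] = orig + h; f_plus = f(x)
        x_flat[i] = orig - h; f_minus = f(x)
        x_flat[i] = orig                    # restore the coordinate
        g_flat[i] = (f_plus - f_minus) / (2 * h)
    return g

f = lambda x: (x ** 2).sum()
x = torch.tensor([1.0, -2.0, 3.0])
print(fdm_central(f, x))                    # ≈ tensor([ 2., -4.,  6.])
```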
@@ -15,7 +15,7 @@ class ForwardGradient(RandomizedFDM):
 
      This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
 
-     .. note::
+     Note:
          This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
          and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
 
@@ -23,8 +23,6 @@ class ForwardGradient(RandomizedFDM):
      Args:
          n_samples (int, optional): number of random gradient samples. Defaults to 1.
          distribution (Distributions, optional): distribution for random gradient samples. Defaults to "gaussian".
-         beta (float, optional):
-             If this is set to a value higher than zero, instead of using directional derivatives in a new random direction on each step, the direction changes gradually with momentum based on this value. This may make it possible to use methods with memory. Defaults to 0.
          pre_generate (bool, optional):
              whether to pre-generate gradient samples before each step. If samples are not pre-generated, whenever a method performs multiple closure evaluations, the gradient will be evaluated in different directions each time. Defaults to True.
          jvp_method (str, optional):
@@ -40,14 +38,13 @@ class ForwardGradient(RandomizedFDM):
          self,
          n_samples: int = 1,
          distribution: Distributions = "gaussian",
-         beta: float = 0,
          pre_generate = True,
          jvp_method: Literal['autograd', 'forward', 'central'] = 'autograd',
          h: float = 1e-3,
          target: GradTarget = "closure",
          seed: int | None | torch.Generator = None,
      ):
-         super().__init__(h=h, n_samples=n_samples, distribution=distribution, beta=beta, target=target, pre_generate=pre_generate, seed=seed)
+         super().__init__(h=h, n_samples=n_samples, distribution=distribution, target=target, pre_generate=pre_generate, seed=seed)
          self.defaults['jvp_method'] = jvp_method
 
      @torch.no_grad
@@ -62,12 +59,14 @@ class ForwardGradient(RandomizedFDM):
          distribution = settings['distribution']
          default = [None]*n_samples
          perturbations = list(zip(*(self.state[p].get('perturbations', default) for p in params)))
-         generator = self._get_generator(settings['seed'], params)
+         generator = self.get_generator(params[0].device, self.defaults['seed'])
 
          grad = None
          for i in range(n_samples):
              prt = perturbations[i]
-             if prt[0] is None: prt = params.sample_like(distribution=distribution, generator=generator)
+             if prt[0] is None:
+                 prt = params.sample_like(distribution=distribution, variance=1, generator=generator)
+
              else: prt = TensorList(prt)
 
              if jvp_method == 'autograd':
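
Reviewer note: a minimal sketch of the forward-gradient estimate this loop accumulates when `jvp_method='autograd'`, written with `torch.func.jvp` on a single tensor and a single Gaussian tangent (illustrative; the module batches this over all parameters, samples, and the alternative finite-difference JVP backends):

```python
import torch
from torch.func import jvp

def forward_gradient(f, x: torch.Tensor) -> torch.Tensor:
    # sample a tangent v ~ N(0, I) and compute the directional derivative ∇f(x)·v
    # with one forward-mode JVP; since E[(∇f(x)·v) v] = ∇f(x), the product v * (∇f(x)·v)
    # is an unbiased single-sample estimate of the gradient
    v = torch.randn_like(x)
    _, dvf = jvp(f, (x,), (v,))
    return v * dvf

f = lambda x: (x ** 2).sum()
x = torch.tensor([1.0, -2.0, 3.0])
g_hat = forward_gradient(f, x)   # noisy estimate of tensor([ 2., -4.,  6.])
```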
@@ -24,63 +24,59 @@ class GradApproximator(Module, ABC):
 
      Example:
 
-         Basic SPSA method implementation.
-
-         .. code-block:: python
-
-             class SPSA(GradApproximator):
-                 def __init__(self, h=1e-3):
-                     defaults = dict(h=h)
-                     super().__init__(defaults)
-
-                 @torch.no_grad
-                 def approximate(self, closure, params, loss):
-                     perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
-
-                     # evaluate params + perturbation
-                     torch._foreach_add_(params, perturbation)
-                     loss_plus = closure(False)
-
-                     # evaluate params - perturbation
-                     torch._foreach_sub_(params, perturbation)
-                     torch._foreach_sub_(params, perturbation)
-                     loss_minus = closure(False)
-
-                     # restore original params
-                     torch._foreach_add_(params, perturbation)
-
-                     # calculate SPSA gradients
-                     spsa_grads = []
-                     for p, pert in zip(params, perturbation):
-                         settings = self.settings[p]
-                         h = settings['h']
-                         d = (loss_plus - loss_minus) / (2*(h**2))
-                         spsa_grads.append(pert * d)
-
-                     # returns tuple: (grads, loss, loss_approx)
-                     # loss must be with initial parameters
-                     # since we only evaluated loss with perturbed parameters
-                     # we only have loss_approx
-                     return spsa_grads, None, loss_plus
-
-     """
+         Basic SPSA method implementation.
+         ```python
+         class SPSA(GradApproximator):
+             def __init__(self, h=1e-3):
+                 defaults = dict(h=h)
+                 super().__init__(defaults)
+
+             @torch.no_grad
+             def approximate(self, closure, params, loss):
+                 perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                 # evaluate params + perturbation
+                 torch._foreach_add_(params, perturbation)
+                 loss_plus = closure(False)
+
+                 # evaluate params - perturbation
+                 torch._foreach_sub_(params, perturbation)
+                 torch._foreach_sub_(params, perturbation)
+                 loss_minus = closure(False)
+
+                 # restore original params
+                 torch._foreach_add_(params, perturbation)
+
+                 # calculate SPSA gradients
+                 spsa_grads = []
+                 for p, pert in zip(params, perturbation):
+                     settings = self.settings[p]
+                     h = settings['h']
+                     d = (loss_plus - loss_minus) / (2*(h**2))
+                     spsa_grads.append(pert * d)
+
+                 # returns tuple: (grads, loss, loss_approx)
+                 # loss must be with initial parameters
+                 # since we only evaluated loss with perturbed parameters
+                 # we only have loss_approx
+                 return spsa_grads, None, loss_plus
+         ```
+     """
      def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
          super().__init__(defaults)
          self._target: GradTarget = target
 
      @abstractmethod
-     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
-         """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""
+     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: torch.Tensor | None) -> tuple[Iterable[torch.Tensor], torch.Tensor | None, torch.Tensor | None]:
+         """Returns a tuple: ``(grad, loss, loss_approx)``, make sure this resets parameters to their original values!"""
 
-     def pre_step(self, var: Var) -> Var | None:
+     def pre_step(self, var: Var) -> None:
          """This runs once before each step, whereas `approximate` may run multiple times per step if further modules
          evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-         return var
 
      @torch.no_grad
      def step(self, var):
-         ret = self.pre_step(var)
-         if isinstance(ret, Var): var = ret
+         self.pre_step(var)
 
          if var.closure is None: raise RuntimeError("Gradient approximation requires closure")
          params, closure, loss = var.params, var.closure, var.loss
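
Reviewer note: a hypothetical end-to-end wiring of the docstring's `SPSA` example into a modular optimizer, assuming the usual torchzero closure convention (a `backward: bool` argument, as the `closure(False)` calls above rely on); module order and `tz.m.LR` follow the FDM examples earlier in this diff.

```python
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(4, 1)
# SPSA is the GradApproximator subclass from the docstring above;
# gradient approximators go first so later modules see the estimated gradient
opt = tz.Modular(model.parameters(), SPSA(h=1e-3), tz.m.LR(1e-2))

X, y = torch.randn(32, 4), torch.randn(32, 1)

def closure(backward=True):
    loss = (model(X) - y).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    opt.step(closure)   # SPSA never calls backward(); it only evaluates closure(False)
```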
@@ -108,4 +104,4 @@ class GradApproximator(Module, ABC):
          else: raise ValueError(self._target)
          return var
 
- _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']
+ _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa4']