torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/{projections/structural.py → experimental/structural_projections.py}

@@ -6,35 +6,18 @@ import torch
  from ...core import Chainable
  from ...utils import vec_to_tensors, TensorList
  from ..optimizers.shampoo import _merge_small_dims
- from .projection import Projection
+ from ..projections import ProjectionBase


- class VectorProjection(Projection):
-     """
-     flattens and concatenates all parameters into a vector
-     """
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)

-     @torch.no_grad
-     def project(self, tensors, var, current):
-         return [torch.cat([u.view(-1) for u in tensors], dim=-1)]
-
-     @torch.no_grad
-     def unproject(self, tensors, var, current):
-         return vec_to_tensors(vec=tensors[0], reference=var.params)
-
-
-
- class TensorizeProjection(Projection):
+ class TensorizeProjection(ProjectionBase):
      """flattens and concatenates all parameters into a vector and then reshapes it into a tensor"""
      def __init__(self, modules: Chainable, max_side: int, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_side=max_side)
          super().__init__(modules, defaults=defaults, project_update=project_update, project_params=project_params, project_grad=project_grad)

      @torch.no_grad
-     def project(self, tensors, var, current):
-         params = var.params
+     def project(self, tensors, params, grads, loss, states, settings, current):
          max_side = self.settings[params[0]]['max_side']
          num_elems = sum(t.numel() for t in tensors)

@@ -60,23 +43,23 @@ class TensorizeProjection(Projection):
          return [vec.view(dims)]

      @torch.no_grad
-     def unproject(self, tensors, var, current):
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
          remainder = self.global_state['remainder']
          # warnings.warn(f'{tensors[0].shape = }')
-         vec = tensors[0].view(-1)
+         vec = projected_tensors[0].view(-1)
          if remainder > 0: vec = vec[:-remainder]
-         return vec_to_tensors(vec, var.params)
+         return vec_to_tensors(vec, params)

- class BlockPartition(Projection):
+ class BlockPartition(ProjectionBase):
      """splits parameters into blocks (for now flatttens them and chunks)"""
      def __init__(self, modules: Chainable, max_size: int, batched: bool = False, project_update=True, project_params=False, project_grad=False):
          defaults = dict(max_size=max_size, batched=batched)
          super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad, defaults=defaults)

      @torch.no_grad
-     def project(self, tensors, var, current):
+     def project(self, tensors, params, grads, loss, states, settings, current):
          partitioned = []
-         for p,t in zip(var.params, tensors):
+         for p,t in zip(params, tensors):
              settings = self.settings[p]
              max_size = settings['max_size']
              n = t.numel()
@@ -101,10 +84,10 @@ class BlockPartition(Projection):
          return partitioned

      @torch.no_grad
-     def unproject(self, tensors, var, current):
-         ti = iter(tensors)
+     def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
+         ti = iter(projected_tensors)
          unprojected = []
-         for p in var.params:
+         for p in params:
              settings = self.settings[p]
              n = p.numel()

@@ -124,28 +107,3 @@ class BlockPartition(Projection):

          return unprojected

-
- class TensorNormsProjection(Projection):
-     def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
-         super().__init__(modules, project_update=project_update, project_params=project_params, project_grad=project_grad)
-
-     @torch.no_grad
-     def project(self, tensors, var, current):
-         orig = self.get_state(var.params, f'{current}_orig')
-         torch._foreach_copy_(orig, tensors)
-
-         norms = torch._foreach_norm(tensors)
-         self.get_state(var.params, f'{current}_orig_norms', cls=TensorList).set_(norms)
-
-         return [torch.stack(norms)]
-
-     @torch.no_grad
-     def unproject(self, tensors, var, current):
-         orig = self.get_state(var.params, f'{current}_orig')
-         orig_norms = torch.stack(self.get_state(var.params, f'{current}_orig_norms'))
-         target_norms = tensors[0]
-
-         orig_norms = torch.where(orig_norms == 0, 1, orig_norms)
-
-         torch._foreach_mul_(orig, (target_norms/orig_norms).detach().cpu().tolist())
-         return orig
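The hunks above change the projection API: `Projection` is now `ProjectionBase` (imported from `..projections`), and `project`/`unproject` receive `params`, `grads`, `loss`, `states`, `settings` and `current` directly instead of a `var` object, with the projected tensors passed as `projected_tensors`. As a rough sketch only, not part of the package, the removed `VectorProjection` ported to the new signature would look roughly like this (the class name is made up, and the constructor mirrors the arguments shown in the hunks):

    import torch
    from torchzero.core import Chainable
    from torchzero.utils import vec_to_tensors
    from torchzero.modules.projections import ProjectionBase

    class FlattenToVector(ProjectionBase):
        """illustrative only: flattens and concatenates all parameters into a single vector"""
        def __init__(self, modules: Chainable, project_update=True, project_params=False, project_grad=False):
            super().__init__(modules, project_update=project_update,
                             project_params=project_params, project_grad=project_grad)

        @torch.no_grad
        def project(self, tensors, params, grads, loss, states, settings, current):
            # concatenate every incoming tensor into one flat vector
            return [torch.cat([t.view(-1) for t in tensors], dim=-1)]

        @torch.no_grad
        def unproject(self, projected_tensors, params, grads, loss, states, settings, current):
            # split the flat vector back into tensors shaped like the parameters
            return vec_to_tensors(vec=projected_tensors[0], reference=params)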
torchzero/modules/experimental/subspace_preconditioners.py

@@ -38,14 +38,19 @@ def apply_subspace_preconditioner(
      return basis @ update_projected # d

  class RandomSubspacePreconditioning(Transform):
-     """Whitens in random slowly changing subspace. Please note that this is experimental and isn't guaranteed to work."""
+     """Whitens in random slowly changing subspace.
+
+     .. warning::
+         Experimental and this is a barebones implementation.
+
+     """
      def __init__(self, k: int, beta: float | None = 0.99, basis_beta: float | None = 0.99, inner: Chainable | None = None):
          defaults = dict(k=k, beta=beta, basis_beta=basis_beta)
          super().__init__(defaults, uses_grad=False)

          if inner is not None: self.set_child('inner', inner)

-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          settings = settings[0]
          g = torch.cat([t.view(-1) for t in tensors])
          k = settings['k']
@@ -79,7 +84,9 @@ class RandomSubspacePreconditioning(Transform):

  class HistorySubspacePreconditioning(Transform):
      """Whitens in subspace spanned by history of gradient differences.
-     Please note that this is experimental and isn't guaranteed to work.
+
+     .. warning::
+         Experimental and this is a barebones implementation.

      Args:
          beta - for preconditioner itself in the basis.
@@ -91,7 +98,7 @@ class HistorySubspacePreconditioning(Transform):

          if inner is not None: self.set_child('inner', inner)

-     def apply(self, tensors, params, grads, loss, states, settings):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          settings = settings[0]

          g = torch.cat([t.view(-1) for t in tensors])
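In this file the `Transform` hook is renamed from `apply` to `apply_tensors`, still taking `(tensors, params, grads, loss, states, settings)`, where `settings` is a per-parameter list of settings dicts. A minimal hypothetical subclass against the renamed hook, assuming `Transform` is importable from `torchzero.core` like `TensorwiseTransform` and that the hook returns the transformed tensors:

    import torch
    from torchzero.core import Transform  # assumed import path

    class GlobalNormalize(Transform):
        """illustrative only: scales the update so its concatenated vector has unit norm"""
        def __init__(self, eps: float = 1e-12):
            defaults = dict(eps=eps)
            super().__init__(defaults, uses_grad=False)

        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            eps = settings[0]['eps']
            # global L2 norm over all tensors, guarded against division by zero
            norm = torch.cat([t.view(-1) for t in tensors]).norm().clamp(min=eps).item()
            torch._foreach_div_(tensors, norm)
            return tensors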
torchzero/modules/experimental/{tada.py → tensor_adagrad.py}

@@ -6,17 +6,21 @@ from ...core import Chainable, TensorwiseTransform
  from ...utils.linalg import matrix_power_eigh


- class TAda(TensorwiseTransform):
-     """3rd order whitening (maybe normalizes skewness). Please note that this is experimental and isn't guaranteed to work."""
+ class TensorAdagrad(TensorwiseTransform):
+     """3rd order whitening (maybe normalizes skewness, but don't quote me on it).
+
+     .. warning::
+         Experimental.
+     """
      def __init__(self, history_size: int = 100, reg: float = 1e-8, update_freq: int = 1, concat_params: bool = True, inner: Chainable | None = None):
          defaults = dict(history_size=history_size, reg=reg)
          super().__init__(defaults, uses_grad=False, update_freq=update_freq, inner=inner, concat_params=concat_params)

      @torch.no_grad
-     def update_tensor(self, tensor, param, grad, loss, state, settings):
-         reg = settings['reg']
+     def update_tensor(self, tensor, param, grad, loss, state, setting):
+         reg = setting['reg']
          if 'history' not in state:
-             state['history'] = deque(maxlen=settings['history_size'])
+             state['history'] = deque(maxlen=setting['history_size'])

          g = tensor.view(-1)
          history = state['history']
@@ -32,7 +36,7 @@ class TAda(TensorwiseTransform):
          state['outer'] = outer.add_(I)

      @torch.no_grad
-     def apply_tensor(self, tensor, param, grad, loss, state, settings):
+     def apply_tensor(self, tensor, param, grad, loss, state, setting):
          outer = state['outer']
          P = matrix_power_eigh(outer, -1/2)
          return (P @ tensor.ravel()).view_as(tensor)
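For `TensorwiseTransform` subclasses, the per-tensor hooks `update_tensor` and `apply_tensor` now receive `setting` (the settings dict of the single parameter being processed) rather than `settings`. A toy subclass in the same shape, purely illustrative and not part of the package, assuming the remaining `TensorwiseTransform` constructor arguments keep their defaults:

    import torch
    from torchzero.core import TensorwiseTransform

    class TensorEMA(TensorwiseTransform):
        """illustrative only: tracks an exponential moving average of each tensor and returns it"""
        def __init__(self, beta: float = 0.9):
            defaults = dict(beta=beta)
            super().__init__(defaults, uses_grad=False)

        @torch.no_grad
        def update_tensor(self, tensor, param, grad, loss, state, setting):
            # accumulate per-parameter state; `setting` holds this parameter's options
            beta = setting['beta']
            if 'ema' not in state: state['ema'] = tensor.clone()
            else: state['ema'].lerp_(tensor, 1 - beta)

        @torch.no_grad
        def apply_tensor(self, tensor, param, grad, loss, state, setting):
            # replace the update with the averaged tensor
            return state['ema'].clone()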
torchzero/modules/functional.py

@@ -7,8 +7,9 @@ storage is always indicated in the docstring.

  Additional functional variants are present in most module files, e.g. `adam_`, `rmsprop_`, `lion_`, etc.
  """
-
- from collections.abc import Callable, Sequence
+ from collections.abc import Callable
+ from typing import overload
+ import torch

  from ..utils import NumberList, TensorList

@@ -206,4 +207,13 @@ def sqrt_centered_ema_sq_(
          ema_sq_fn=lambda *a, **kw: centered_ema_sq_(*a, **kw, exp_avg_=exp_avg_)
      )

+ @overload
+ def safe_scaling_(tensors_: torch.Tensor) -> torch.Tensor: ...
+ @overload
+ def safe_scaling_(tensors_: TensorList) -> TensorList: ...
+ def safe_scaling_(tensors_: torch.Tensor | TensorList):
+     if isinstance(tensors_, torch.Tensor): scale = 1 / tensors_.abs().sum()
+     else: scale = 1 / tensors_.abs().global_sum()
+     scale = scale.clip(min=torch.finfo(tensors_[0].dtype).eps, max=1)
+     return tensors_.mul_(scale)

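The new `safe_scaling_` helper above multiplies its argument in place by `1 / sum(|x|)` with the factor clipped to `[eps, 1]`, so an update can only shrink, and the scale never degenerates to zero on near-zero inputs. Hypothetical usage, assuming the function is imported from `torchzero.modules.functional` and that `TensorList` can be built from a plain list of tensors:

    import torch
    from torchzero.utils import TensorList
    from torchzero.modules.functional import safe_scaling_  # assumed import path

    # works on a TensorList of per-parameter updates...
    update = TensorList([torch.randn(4, 4), torch.randn(10)])
    safe_scaling_(update)  # in place: multiplied by a scale in [eps, 1]

    # ...and on a single tensor
    g = torch.randn(100)
    g = safe_scaling_(g)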
torchzero/modules/grad_approximation/fdm.py

@@ -77,8 +77,11 @@ def _central4(closure: Callable[..., float], param:torch.Tensor, idx: int, h, v_
      return v_0, v_plus1, (v_minus2 - 8*v_minus1 + 8*v_plus1 - v_plus2) / (12 * h)

  _FD_FUNCS = {
+     "forward": _forward2,
      "forward2": _forward2,
+     "backward": _backward2,
      "backward2": _backward2,
+     "central": _central2,
      "central2": _central2,
      "central3": _central2, # they are the same
      "forward3": _forward3,
@@ -88,19 +91,43 @@ _FD_FUNCS = {


  class FDM(GradApproximator):
-     """Approximate gradients via finite difference method
+     """Approximate gradients via finite difference method.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.

      Args:
          h (float, optional): magnitude of parameter perturbation. Defaults to 1e-3.
          formula (_FD_Formula, optional): finite difference formula. Defaults to 'central2'.
          target (GradTarget, optional): what to set on var. Defaults to 'closure'.
+
+     Examples:
+         plain FDM:
+
+         .. code-block:: python
+
+             fdm = tz.Modular(model.parameters(), tz.m.FDM(), tz.m.LR(1e-2))
+
+         Any gradient-based method can use FDM-estimated gradients seamlessly.
+
+         .. code-block:: python
+
+             fdm_ncg = tz.Modular(
+                 model.parameters(),
+                 tz.m.FDM(),
+                 # set hvp_method to "forward" so that it
+                 # uses gradient difference instead of autograd
+                 tz.m.NewtonCG(hvp_method="forward"),
+                 tz.m.Backtracking()
+             )
      """
-     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central2', target: GradTarget = 'closure'):
+     def __init__(self, h: float=1e-3, formula: _FD_Formula = 'central', target: GradTarget = 'closure'):
          defaults = dict(h=h, formula=formula)
          super().__init__(defaults, target=target)

      @torch.no_grad
-     def approximate(self, closure, params, loss, var):
+     def approximate(self, closure, params, loss):
          grads = []
          loss_approx = None

torchzero/modules/grad_approximation/forward_gradient.py

@@ -4,14 +4,21 @@ from typing import Any, Literal

  import torch

- from ...utils import Distributions, NumberList, TensorList, generic_eq
+ from ...utils import Distributions, NumberList, TensorList
  from ...utils.derivatives import jvp, jvp_fd_central, jvp_fd_forward
  from .grad_approximator import GradApproximator, GradTarget
  from .rfdm import RandomizedFDM


  class ForwardGradient(RandomizedFDM):
-     """Forward gradient method, same as randomized finite difference but directional derivative is estimated via autograd (as jacobian vector product)
+     """Forward gradient method.
+
+     This method samples one or more directional derivatives evaluated via autograd jacobian-vector products. This is very similar to randomized finite difference.
+
+     .. note::
+         This module is a gradient approximator. It modifies the closure to evaluate the estimated gradients,
+         and further closure-based modules will use the modified closure. All modules after this will use estimated gradients.
+

      Args:
          n_samples (int, optional): number of random gradient samples. Defaults to 1.
@@ -24,6 +31,9 @@ class ForwardGradient(RandomizedFDM):
              how to calculate jacobian vector product, note that with `forward` and 'central' this is equivalent to randomized finite difference. Defaults to 'autograd'.
          h (float, optional): finite difference step size of jvp_method is set to `forward` or `central`. Defaults to 1e-3.
          target (GradTarget, optional): what to set on var. Defaults to "closure".
+
+     References:
+         Baydin, A. G., Pearlmutter, B. A., Syme, D., Wood, F., & Torr, P. (2022). Gradients without backpropagation. arXiv preprint arXiv:2202.08587.
      """
      PRE_MULTIPLY_BY_H = False
      def __init__(
@@ -41,7 +51,7 @@
          self.defaults['jvp_method'] = jvp_method

      @torch.no_grad
-     def approximate(self, closure, params, loss, var):
+     def approximate(self, closure, params, loss):
          params = TensorList(params)
          loss_approx = None

torchzero/modules/grad_approximation/grad_approximator.py

@@ -14,17 +14,62 @@ class GradApproximator(Module, ABC):
      """Base class for gradient approximations.
      This is an abstract class, to use it, subclass it and override `approximate`.

+     GradientApproximator modifies the closure to evaluate the estimated gradients,
+     and further closure-based modules will use the modified closure.
+
      Args:
          defaults (dict[str, Any] | None, optional): dict with defaults. Defaults to None.
          target (str, optional):
              whether to set `var.grad`, `var.update` or 'var.closure`. Defaults to 'closure'.
-     """
+
+     Example:
+
+         Basic SPSA method implementation.
+
+         .. code-block:: python
+
+             class SPSA(GradApproximator):
+                 def __init__(self, h=1e-3):
+                     defaults = dict(h=h)
+                     super().__init__(defaults)
+
+                 @torch.no_grad
+                 def approximate(self, closure, params, loss):
+                     perturbation = [rademacher_like(p) * self.settings[p]['h'] for p in params]
+
+                     # evaluate params + perturbation
+                     torch._foreach_add_(params, perturbation)
+                     loss_plus = closure(False)
+
+                     # evaluate params - perturbation
+                     torch._foreach_sub_(params, perturbation)
+                     torch._foreach_sub_(params, perturbation)
+                     loss_minus = closure(False)
+
+                     # restore original params
+                     torch._foreach_add_(params, perturbation)
+
+                     # calculate SPSA gradients
+                     spsa_grads = []
+                     for p, pert in zip(params, perturbation):
+                         settings = self.settings[p]
+                         h = settings['h']
+                         d = (loss_plus - loss_minus) / (2*(h**2))
+                         spsa_grads.append(pert * d)
+
+                     # returns tuple: (grads, loss, loss_approx)
+                     # loss must be with initial parameters
+                     # since we only evaluated loss with perturbed parameters
+                     # we only have loss_approx
+                     return spsa_grads, None, loss_plus
+
+     """
      def __init__(self, defaults: dict[str, Any] | None = None, target: GradTarget = 'closure'):
          super().__init__(defaults)
          self._target: GradTarget = target

      @abstractmethod
-     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None, var: Var) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
+     def approximate(self, closure: Callable, params: list[torch.Tensor], loss: _Scalar | None) -> tuple[Iterable[torch.Tensor], _Scalar | None, _Scalar | None]:
          """Returns a tuple: (grad, loss, loss_approx), make sure this resets parameters to their original values!"""

      def pre_step(self, var: Var) -> Var | None:
@@ -45,9 +90,9 @@ class GradApproximator(Module, ABC):
          def approx_closure(backward=True):
              if backward:
                  # set loss to None because closure might be evaluated at different points
-                 grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None, var=var)
+                 grad, l, l_approx = self.approximate(closure=closure, params=params, loss=None)
                  for p, g in zip(params, grad): p.grad = g
-                 return l if l is not None else l_approx
+                 return l if l is not None else closure(False)
              return closure(False)

          var.closure = approx_closure
@@ -55,7 +100,7 @@
          # if var.grad is not None:
          # warnings.warn('Using grad approximator when `var.grad` is already set.')

-         grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss, var=var)
+         grad,loss,loss_approx = self.approximate(closure=closure, params=params, loss=loss)
          if loss_approx is not None: var.loss_approx = loss_approx
          if loss is not None: var.loss = var.loss_approx = loss
          if self._target == 'grad': var.grad = list(grad)
@@ -63,4 +108,4 @@
          else: raise ValueError(self._target)
          return var

- _FD_Formula = Literal['forward2', 'backward2', 'forward3', 'backward3', 'central2', 'central4']
+ _FD_Formula = Literal['forward', 'forward2', 'backward', 'backward2', 'central', 'central2', 'central3', 'forward3', 'backward3', 'central4', 'forward4', 'forward5', 'bspsa5']