torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/olbfgs.py (deleted)
@@ -1,196 +0,0 @@
- from collections import deque
- from functools import partial
- from operator import itemgetter
- from typing import Literal
-
- import torch
-
- from ...core import Chainable, Module, Transform, Var, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist
- from .lbfgs import _adaptive_damping, lbfgs
-
-
- @torch.no_grad
- def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
-     assert var.closure is not None
-     with torch.enable_grad(): var.closure()
-     grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
-     s_k = var.params - prev_params
-     y_k = grad - prev_grad
-     ys_k = s_k.dot(y_k)
-
-     if damping:
-         s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-     if ys_k > 1e-10:
-         s_history.append(s_k)
-         y_history.append(y_k)
-         sy_history.append(ys_k)
-
-
-
- class OnlineLBFGS(Module):
-     """Online L-BFGS.
-     Parameter and gradient differences are sampled from the same mini-batch by performing an extra forward and backward pass.
-     However I did a bunch of experiments and the online part doesn't seem to help. Normal L-BFGS is usually still
-     better because it performs twice as many steps, and it is reasonably stable with normalization or grafting.
-
-     Args:
-         history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-         sample_grads (str, optional):
-             - "before" - samples current mini-batch gradient at previous and current parameters, calculates y_k
-               and adds it to history before stepping.
-             - "after" - samples current mini-batch gradient at parameters before stepping and after updating parameters.
-               s_k and y_k are added after parameter update, therefore they are delayed by 1 step.
-
-             In practice both modes behave very similarly. Defaults to 'before'.
-         tol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
-         damping (bool, optional):
-             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
-         init_damping (float, optional):
-             initial damping for adaptive dampening. Defaults to 0.9.
-         eigval_bounds (tuple, optional):
-             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
-         params_beta (float | None, optional):
-             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
-         grads_beta (float | None, optional):
-             if not None, EMA of gradients is used for preconditioner update. Defaults to None.
-         update_freq (int, optional):
-             how often to update L-BFGS history. Defaults to 1.
-         z_beta (float | None, optional):
-             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-         inner (Chainable | None, optional):
-             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
-     """
-     def __init__(
-         self,
-         history_size=10,
-         sample_grads: Literal['before', 'after'] = 'before',
-         tol: float | None = 1e-10,
-         damping: bool = False,
-         init_damping=0.9,
-         eigval_bounds=(0.5, 50),
-         z_beta: float | None = None,
-         inner: Chainable | None = None,
-     ):
-         defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, sample_grads=sample_grads, z_beta=z_beta)
-         super().__init__(defaults)
-
-         self.global_state['s_history'] = deque(maxlen=history_size)
-         self.global_state['y_history'] = deque(maxlen=history_size)
-         self.global_state['sy_history'] = deque(maxlen=history_size)
-
-         if inner is not None:
-             self.set_child('inner', inner)
-
-     def reset(self):
-         """Resets the internal state of the L-SR1 module."""
-         # super().reset() # Clears self.state (per-parameter) if any, and "step"
-         # Re-initialize L-SR1 specific global state
-         self.state.clear()
-         self.global_state['step'] = 0
-         self.global_state['s_history'].clear()
-         self.global_state['y_history'].clear()
-         self.global_state['sy_history'].clear()
-
-     @torch.no_grad
-     def step(self, var):
-         assert var.closure is not None
-
-         params = as_tensorlist(var.params)
-         update = as_tensorlist(var.get_update())
-         step = self.global_state.get('step', 0)
-         self.global_state['step'] = step + 1
-
-         # history of s and k
-         s_history: deque[TensorList] = self.global_state['s_history']
-         y_history: deque[TensorList] = self.global_state['y_history']
-         sy_history: deque[torch.Tensor] = self.global_state['sy_history']
-
-         tol, damping, init_damping, eigval_bounds, sample_grads, z_beta = itemgetter(
-             'tol', 'damping', 'init_damping', 'eigval_bounds', 'sample_grads', 'z_beta')(self.settings[params[0]])
-
-         # sample gradient at previous params with current mini-batch
-         if sample_grads == 'before':
-             prev_params = self.get_state(params, 'prev_params', cls=TensorList)
-             if step == 0:
-                 s_k = None; y_k = None; ys_k = None
-             else:
-                 s_k = params - prev_params
-
-                 current_params = params.clone()
-                 params.set_(prev_params)
-                 with torch.enable_grad(): var.closure()
-                 y_k = update - params.grad
-                 ys_k = s_k.dot(y_k)
-                 params.set_(current_params)
-
-                 if damping:
-                     s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
-
-                 if ys_k > 1e-10:
-                     s_history.append(s_k)
-                     y_history.append(y_k)
-                     sy_history.append(ys_k)
-
-             prev_params.copy_(params)
-
-         # use previous s_k, y_k pair, samples gradient at current batch before and after updating parameters
-         elif sample_grads == 'after':
-             if len(s_history) == 0:
-                 s_k = None; y_k = None; ys_k = None
-             else:
-                 s_k = s_history[-1]
-                 y_k = y_history[-1]
-                 ys_k = s_k.dot(y_k)
-
-             # this will run after params are updated by Modular after running all future modules
-             var.post_step_hooks.append(
-                 partial(
-                     _store_sk_yk_after_step_hook,
-                     prev_params=params.clone(),
-                     prev_grad=update.clone(),
-                     damping=damping,
-                     init_damping=init_damping,
-                     eigval_bounds=eigval_bounds,
-                     s_history=s_history,
-                     y_history=y_history,
-                     sy_history=sy_history,
-                 ))
-
-         else:
-             raise ValueError(sample_grads)
-
-         # step with inner module before applying preconditioner
-         if self.children:
-             update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
-
-         # tolerance on gradient difference to avoid exploding after converging
-         if tol is not None:
-             if y_k is not None and y_k.abs().global_max() <= tol:
-                 var.update = update # may have been updated by inner module, probably makes sense to use it here?
-                 return var
-
-         # lerp initial H^-1 @ q guess
-         z_ema = None
-         if z_beta is not None:
-             z_ema = self.get_state(params, 'z_ema', cls=TensorList)
-
-         # precondition
-         dir = lbfgs(
-             tensors_=as_tensorlist(update),
-             s_history=s_history,
-             y_history=y_history,
-             sy_history=sy_history,
-             y_k=y_k,
-             ys_k=ys_k,
-             z_beta = z_beta,
-             z_ema = z_ema,
-             step=step
-         )
-
-         var.update = dir
-
-         return var
-
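For context on the removed module above: `OnlineLBFGS` maintains deques of parameter differences `s_k`, gradient differences `y_k`, and their inner products, and hands them to an `lbfgs(...)` helper for preconditioning. The sketch below is a minimal, self-contained illustration of the standard L-BFGS two-loop recursion that such a helper implements; it works on flat tensors, and the function name and signature are assumptions for illustration only, not torchzero's actual `lbfgs` API.

```python
from collections import deque

import torch

def two_loop_recursion(grad: torch.Tensor, s_history: deque, y_history: deque,
                       sy_history: deque) -> torch.Tensor:
    """Approximate H^{-1} @ grad from histories of s_k, y_k and s_k·y_k (newest last)."""
    q = grad.clone()
    alphas = []
    # first loop: walk the history from newest to oldest
    for s, y, sy in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
        alpha = s.dot(q) / sy
        q.sub_(alpha * y)
        alphas.append(alpha)
    # initial Hessian approximation gamma * I, with gamma = (s·y) / (y·y) of the newest pair
    if sy_history:
        y_new = y_history[-1]
        q.mul_(sy_history[-1] / y_new.dot(y_new))
    # second loop: walk the history from oldest to newest
    for s, y, sy, alpha in zip(s_history, y_history, sy_history, reversed(alphas)):
        beta = y.dot(q) / sy
        q.add_((alpha - beta) * s)
    return q  # the descent direction is -q, optionally scaled by a line search
```

The "online" twist in the deleted module lies only in how the (s_k, y_k) pairs are collected: both gradients in a pair come from the same mini-batch, via the extra closure evaluation in `step` ("before") or the post-step hook ("after").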
torchzero/modules/smoothing/gaussian.py (deleted)
@@ -1,164 +0,0 @@
- import warnings
- from abc import ABC, abstractmethod
- from collections.abc import Callable, Sequence
- from functools import partial
- from typing import Literal
-
- import torch
-
- from ...core import Modular, Module, Var
- from ...utils import NumberList, TensorList
- from ...utils.derivatives import jacobian_wrt
- from ..grad_approximation import GradApproximator, GradTarget
-
-
- class Reformulation(Module, ABC):
-     def __init__(self, defaults):
-         super().__init__(defaults)
-
-     @abstractmethod
-     def closure(self, backward: bool, closure: Callable, params:list[torch.Tensor], var: Var) -> tuple[float | torch.Tensor, Sequence[torch.Tensor] | None]:
-         """returns loss and gradient, if backward is False then gradient can be None"""
-
-     def pre_step(self, var: Var) -> Var | None:
-         """This runs once before each step, whereas `closure` may run multiple times per step if further modules
-         evaluate gradients at multiple points. This is useful for example to pre-generate new random perturbations."""
-         return var
-
-     def step(self, var):
-         ret = self.pre_step(var)
-         if isinstance(ret, Var): var = ret
-
-         if var.closure is None: raise RuntimeError("Reformulation requires closure")
-         params, closure = var.params, var.closure
-
-
-         def modified_closure(backward=True):
-             loss, grad = self.closure(backward, closure, params, var)
-
-             if grad is not None:
-                 for p,g in zip(params, grad):
-                     p.grad = g
-
-             return loss
-
-         var.closure = modified_closure
-         return var
-
-
- def _decay_sigma_(self: Module, params):
-     for p in params:
-         state = self.state[p]
-         settings = self.settings[p]
-         state['sigma'] *= settings['decay']
-
- def _generate_perturbations_to_state_(self: Module, params: TensorList, n_samples, sigmas, generator):
-     perturbations = [params.sample_like(generator=generator) for _ in range(n_samples)]
-     torch._foreach_mul_([p for l in perturbations for p in l], [v for vv in sigmas for v in [vv]*n_samples])
-     for param, prt in zip(params, zip(*perturbations)):
-         self.state[param]['perturbations'] = prt
-
- def _clear_state_hook(optimizer: Modular, var: Var, self: Module):
-     for m in optimizer.unrolled_modules:
-         if m is not self:
-             m.reset()
-
- class GaussianHomotopy(Reformulation):
-     def __init__(
-         self,
-         n_samples: int,
-         init_sigma: float,
-         tol: float | None = 1e-4,
-         decay=0.5,
-         max_steps: int | None = None,
-         clear_state=True,
-         seed: int | None = None,
-     ):
-         defaults = dict(n_samples=n_samples, init_sigma=init_sigma, tol=tol, decay=decay, max_steps=max_steps, clear_state=clear_state, seed=seed)
-         super().__init__(defaults)
-
-
-     def _get_generator(self, seed: int | None | torch.Generator, params: list[torch.Tensor]):
-         if 'generator' not in self.global_state:
-             if isinstance(seed, torch.Generator): self.global_state['generator'] = seed
-             elif seed is not None: self.global_state['generator'] = torch.Generator(params[0].device).manual_seed(seed)
-             else: self.global_state['generator'] = None
-         return self.global_state['generator']
-
-     def pre_step(self, var):
-         params = TensorList(var.params)
-         settings = self.settings[params[0]]
-         n_samples = settings['n_samples']
-         init_sigma = [self.settings[p]['init_sigma'] for p in params]
-         sigmas = self.get_state(params, 'sigma', init=init_sigma)
-
-         if any('perturbations' not in self.state[p] for p in params):
-             generator = self._get_generator(settings['seed'], params)
-             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
-
-         # sigma decay rules
-         max_steps = settings['max_steps']
-         decayed = False
-         if max_steps is not None and max_steps > 0:
-             level_steps = self.global_state['level_steps'] = self.global_state.get('level_steps', 0) + 1
-             if level_steps > max_steps:
-                 self.global_state['level_steps'] = 0
-                 _decay_sigma_(self, params)
-                 decayed = True
-
-         tol = settings['tol']
-         if tol is not None and not decayed:
-             if not any('prev_params' in self.state[p] for p in params):
-                 prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
-             else:
-                 prev_params = self.get_state(params, 'prev_params', cls=TensorList, init='param')
-                 s = params - prev_params
-
-                 if s.abs().global_max() <= tol:
-                     _decay_sigma_(self, params)
-                     decayed = True
-
-             prev_params.copy_(params)
-
-         if decayed:
-             generator = self._get_generator(settings['seed'], params)
-             _generate_perturbations_to_state_(self, params=params, n_samples=n_samples, sigmas=sigmas, generator=generator)
-             if settings['clear_state']:
-                 var.post_step_hooks.append(partial(_clear_state_hook, self=self))
-
-     @torch.no_grad
-     def closure(self, backward, closure, params, var):
-         params = TensorList(params)
-
-         settings = self.settings[params[0]]
-         n_samples = settings['n_samples']
-
-         perturbations = list(zip(*(self.state[p]['perturbations'] for p in params)))
-
-         loss = None
-         grad = None
-         for i in range(n_samples):
-             prt = perturbations[i]
-
-             params.add_(prt)
-             if backward:
-                 with torch.enable_grad(): l = closure()
-                 if grad is None: grad = params.grad
-                 else: grad += params.grad
-
-             else:
-                 l = closure(False)
-
-             if loss is None: loss = l
-             else: loss = loss+l
-
-             params.sub_(prt)
-
-         assert loss is not None
-         if n_samples > 1:
-             loss = loss / n_samples
-             if backward:
-                 assert grad is not None
-                 grad.div_(n_samples)
-
-         return loss, grad
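For context on the second removed file: `GaussianHomotopy` reformulates the objective as a Gaussian-smoothed surrogate, roughly f_sigma(x) = E_{u ~ N(0, I)}[f(x + sigma·u)], estimated by averaging loss and gradient over a small set of fixed perturbations and shrinking sigma when progress stalls. The sketch below shows that Monte-Carlo estimate in plain PyTorch; the helper name and arguments are hypothetical and omit torchzero's `Reformulation`/`Var` closure plumbing.

```python
import torch

def smoothed_loss_and_grad(closure, params, perturbations):
    """Average the loss and gradient of `closure` over pre-scaled parameter perturbations.

    `closure()` is expected to zero gradients, run the forward pass, call backward and
    return the loss; `perturbations` is a list of per-sample tensor lists matching
    `params`, each already multiplied by the current sigma.
    """
    total_loss = 0.0
    grad = [torch.zeros_like(p) for p in params]
    for sample in perturbations:
        with torch.no_grad():
            for p, u in zip(params, sample):
                p.add_(u)                      # evaluate at x + sigma * u
        loss = closure()                       # forward + backward at the shifted point
        total_loss += float(loss)
        with torch.no_grad():
            for g, p in zip(grad, params):
                if p.grad is not None:
                    g.add_(p.grad)
            for p, u in zip(params, sample):
                p.sub_(u)                      # restore the original parameters
    n = len(perturbations)
    return total_loss / n, [g.div_(n) for g in grad]
```

Decaying sigma toward zero (the `decay`, `tol`, and `max_steps` settings above) gradually recovers the original, unsmoothed objective, which is the homotopy part of the scheme.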