torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/trust_region/trust_region.py
@@ -0,0 +1,350 @@
+ import math
+ import warnings
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Mapping, Sequence
+ from functools import partial
+ from typing import Any, Literal, Protocol, cast, final, overload
+
+ import torch
+
+ from ...core import Chainable, Module, Var, apply_transform
+ from ...utils import TensorList, safe_dict_update_, tofloat, vec_to_tensors, generic_finfo, generic_vector_norm
+ from ...utils.linalg.linear_operator import LinearOperator
+
+
+ def _flatten_tensors(tensors: list[torch.Tensor]):
+     return torch.cat([t.ravel() for t in tensors])
+
+
+ class _RadiusStrategy(Protocol):
+     def __call__(
+         self,
+         params: Sequence[torch.Tensor],
+         closure: Callable,
+         f: float,
+         g: torch.Tensor,
+         H: LinearOperator,
+         d: torch.Tensor,
+         trust_radius: float,
+         eta: float, # 0.0
+         nplus: float, # 3.5
+         nminus: float, # 0.25
+         rho_good: float, # 0.99
+         rho_bad: float, # 1e-4
+         boundary_tol: float | None,
+         init: float,
+         state: Mapping[str, Any],
+         settings: Mapping[str, Any],
+         radius_fn: Callable | None = torch.linalg.vector_norm,
+     ) -> tuple[float, bool]:
+         """returns (new trust_region value, success).
+
+         Args:
+             params (Sequence[torch.Tensor]): params tensor list
+             closure (Callable): closure
+             d (torch.Tensor):
+                 current update vector with current trust_region, which is SUBTRACTED from parameters.
+                 May be exact solution to (B+yI)x=g, approximate, or a solution to a different subproblem
+                 (e.g. cubic regularization).
+             f (float | torch.Tensor): loss at x0
+             g (torch.Tensor): gradient vector
+             H (LinearOperator | None): hessian approximation
+             trust_radius (float): current trust region value
+             eta (float, optional):
+                 if ratio of actual to predicted reduction is larger than this, step is accepted.
+                 When :code:`hess_module` is GaussNewton, this can be set to 0. Defaults to 0.15.
+             nplus (float, optional): increase factor on successful steps. Defaults to 1.5.
+             nminus (float, optional): decrease factor on unsuccessful steps. Defaults to 0.75.
+             rho_good (float, optional):
+                 if ratio of actual to predicted reduction is larger than this, trust region size is multiplied by `nplus`.
+             rho_bad (float, optional):
+                 if ratio of actual to predicted reduction is less than this, trust region size is multiplied by `nminus`.
+             boundary_tol (float | None, optional):
+                 The trust region only increases when the suggested step's norm is at least `(1-boundary_tol)*trust_region`.
+                 This prevents increasing the trust region when the solution is not on the boundary. Defaults to 1e-2.
+             init (float, optional): Initial trust region value. Defaults to 1.
+             state (dict, optional): global state of the module for storing persistent info.
+             settings (dict, optional): all settings in case this strategy has other settings.
+             radius_fn (Callable | None, optional):
+                 function that accepts ``(d: torch.Tensor)`` and returns the actual radius of ``d``
+                 (e.g. the L2 norm for an L2 trust region).
+         """
+         ... # pylint:disable=unnecessary-ellipsis
+
+ def _get_rho(params: Sequence[torch.Tensor], closure: Callable,
+              f: float, g: torch.Tensor, H: LinearOperator, d: torch.Tensor):
+     """rho is reduction/pred_reduction"""
+
+     # evaluate actual loss reduction
+     update_unflattened = vec_to_tensors(d, params)
+     params = TensorList(params)
+     x0 = params.clone() # same as in line searches, large directions are undone very imprecisely
+
+     params -= update_unflattened
+     f_star = closure(False)
+     params.set_(x0)
+
+     reduction = f - f_star
+
+     # predicted reduction is g.T @ d - 0.5 * d.T @ B @ d
+     Hu = H.matvec(d)
+     pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)
+
+     rho = reduction / (pred_reduction.clip(min=torch.finfo(g.dtype).tiny * 2))
+     return rho, f_star, reduction, pred_reduction
+
+ def _get_rho_tensorlist(params: Sequence[torch.Tensor], closure: Callable,
+                         f: float, g: TensorList, Hvp: Callable[[TensorList], TensorList], d: TensorList):
+     """rho is reduction/pred_reduction"""
+     params = TensorList(params)
+     x0 = params.clone() # same as in line searches, large directions are undone very imprecisely
+
+     # evaluate before modifying params to not break autograd
+     Hu = Hvp(d)
+
+     # actual f
+     params -= d
+     f_star = closure(False)
+     params.copy_(x0)
+
+     reduction = f - f_star
+
+     # predicted reduction is g.T @ d - 0.5 * d.T @ B @ d
+     pred_reduction = g.dot(d) - 0.5 * d.dot(Hu)
+
+     rho = reduction / (pred_reduction.clip(min=torch.finfo(g[0].dtype).tiny * 2))
+     return rho, f_star, reduction, pred_reduction
+
+ @torch.no_grad
+ def default_radius(
+     params: Sequence[torch.Tensor],
+     closure: Callable,
+     f: float,
+     g: torch.Tensor | TensorList,
+     H: LinearOperator | Callable,
+     d: torch.Tensor | TensorList,
+     trust_radius: float,
+     eta: float, # 0.0
+     nplus: float, # 3.5
+     nminus: float, # 0.25
+     rho_good: float, # 0.99
+     rho_bad: float, # 1e-4
+     boundary_tol: float | None,
+     init: float,
+     state: Mapping[str, Any],
+     settings: Mapping[str, Any],
+     radius_fn: Callable | None = generic_vector_norm,
+     check_overflow: bool = True,
+     # dynamic_nminus: bool=False,
+ ) -> tuple[float, bool]:
+
+     # when rho_bad < rho < eta, the step is rejected but the trust radius is not decreased.
+     if eta > rho_bad:
+         warnings.warn(f"trust region eta={eta} is larger than rho_bad={rho_bad}, "
+                       "this can lead to trust region getting stuck.")
+
+     if isinstance(g, torch.Tensor):
+         rho, f_star, _, _ = _get_rho(params=params, closure=closure, f=f, g=g, H=H, d=d) # pyright:ignore[reportArgumentType]
+     else:
+         rho, f_star, _, _ = _get_rho_tensorlist(params=params, closure=closure, f=f, g=g, Hvp=H, d=d) # pyright:ignore[reportArgumentType]
+
+     is_finite = math.isfinite(f_star)
+
+     # find boundary of current step
+     if radius_fn is None: d_radius = trust_radius
+     else: d_radius = radius_fn(d)
+
+     # failed step
+     if rho < rho_bad or not is_finite:
+         # if dynamic_nminus and rho > 0: nminus = nminus * max(rho, 1e-4)
+         trust_radius = d_radius*nminus
+
+     # very good step
+     elif rho > rho_good and is_finite:
+         if (boundary_tol is None) or (trust_radius-d_radius)/trust_radius < boundary_tol:
+             trust_radius = max(trust_radius, d_radius*nplus)
+
+     # prevent very small or large values
+     if check_overflow:
+         finfo = generic_finfo(g)
+         if trust_radius < finfo.tiny*2 or trust_radius > finfo.max/2:
+             trust_radius = init
+
+     # return new trust region and success boolean
+     return tofloat(trust_radius), rho > eta and is_finite
+
+
+ def fixed_radius(
+     params: Sequence[torch.Tensor],
+     closure: Callable,
+     f: float,
+     g: torch.Tensor,
+     H: LinearOperator,
+     d: torch.Tensor,
+     trust_radius: float,
+     eta: float, # 0.0
+     nplus: float, # 3.5
+     nminus: float, # 0.25
+     rho_good: float, # 0.99
+     rho_bad: float, # 1e-4
+     boundary_tol: float | None,
+     init: float,
+     state: Mapping[str, Any],
+     settings: Mapping[str, Any],
+     radius_fn: Callable | None = torch.linalg.vector_norm,
+ ) -> tuple[float, bool]:
+     return init, True
+
+ _RADIUS_KEYS = Literal['default', 'fixed']
+ _RADIUS_STRATEGIES: dict[_RADIUS_KEYS, _RadiusStrategy] = {
+     "default": default_radius,
+     "fixed": fixed_radius,
+     # "dynamic": partial(default_radius, dynamic_nminus=True)
+ }
+
+ class TrustRegionBase(Module, ABC):
+     def __init__(
+         self,
+         defaults: dict | None,
+         hess_module: Chainable,
+         # suggested default values:
+         # Gould, Nicholas IM, et al. "Sensitivity of trust-region algorithms to their parameters." 4OR 3.3 (2005): 227-241.
+         # which I found from https://github.com/patrick-kidger/optimistix/blob/c1dad7e75fc35bd5a4977ac3a872991e51e83d2c/optimistix/_solver/trust_region.py#L113-200
+         eta: float, # 0.0
+         nplus: float, # 3.5
+         nminus: float, # 0.25
+         rho_good: float, # 0.99
+         rho_bad: float, # 1e-4
+         boundary_tol: float | None, # None or 1e-1
+         init: float, # 1
+         max_attempts: int, # 10
+         radius_strategy: _RadiusStrategy | _RADIUS_KEYS, # "default"
+         radius_fn: Callable | None, # torch.linalg.vector_norm
+         update_freq: int = 1,
+         inner: Chainable | None = None,
+     ):
+         if isinstance(radius_strategy, str): radius_strategy = _RADIUS_STRATEGIES[radius_strategy]
+         if defaults is None: defaults = {}
+
+         safe_dict_update_(
+             defaults,
+             dict(eta=eta, nplus=nplus, nminus=nminus, rho_good=rho_good, rho_bad=rho_bad, init=init,
+                  update_freq=update_freq, max_attempts=max_attempts, radius_strategy=radius_strategy,
+                  boundary_tol=boundary_tol)
+         )
+
+         super().__init__(defaults)
+
+         self._radius_fn = radius_fn
+         self.set_child('hess_module', hess_module)
+
+         if inner is not None:
+             self.set_child('inner', inner)
+
+     @abstractmethod
+     def trust_solve(
+         self,
+         f: float,
+         g: torch.Tensor,
+         H: LinearOperator,
+         radius: float,
+         params: list[torch.Tensor],
+         closure: Callable,
+         settings: Mapping[str, Any],
+     ) -> torch.Tensor:
+         """Solve Hx=g with a trust region penalty/bound defined by `radius`"""
+         ... # pylint:disable=unnecessary-ellipsis
+
+     def trust_region_update(self, var: Var, H: LinearOperator | None) -> None:
+         """updates the state of this module after H or B have been updated, if necessary"""
+
+     def trust_region_apply(self, var: Var, tensors: list[torch.Tensor], H: LinearOperator | None) -> Var:
+         """Solves the trust region subproblem and outputs ``Var`` with the solution direction."""
+         assert H is not None
+
+         params = TensorList(var.params)
+         settings = self.settings[params[0]]
+         g = _flatten_tensors(tensors)
+
+         max_attempts = settings['max_attempts']
+
+         # loss at x_0
+         loss = var.loss
+         closure = var.closure
+         if closure is None: raise RuntimeError("Trust region requires closure")
+         if loss is None: loss = var.get_loss(False)
+         loss = tofloat(loss)
+
+         # trust region step and update
+         success = False
+         d = None
+         while not success:
+             max_attempts -= 1
+             if max_attempts < 0: break
+
+             trust_radius = self.global_state.get('trust_radius', settings['init'])
+
+             # solve Hx=g
+             d = self.trust_solve(f=loss, g=g, H=H, radius=trust_radius, params=params, closure=closure, settings=settings)
+
+             # update trust radius
+             radius_strategy: _RadiusStrategy = settings['radius_strategy']
+             self.global_state["trust_radius"], success = radius_strategy(
+                 params=params,
+                 closure=closure,
+                 d=d,
+                 f=loss,
+                 g=g,
+                 H=H,
+                 trust_radius=trust_radius,
+
+                 eta=settings["eta"],
+                 nplus=settings["nplus"],
+                 nminus=settings["nminus"],
+                 rho_good=settings["rho_good"],
+                 rho_bad=settings["rho_bad"],
+                 boundary_tol=settings["boundary_tol"],
+                 init=settings["init"],
+
+                 state=self.global_state,
+                 settings=settings,
+                 radius_fn=self._radius_fn,
+             )
+
+         assert d is not None
+         if success: var.update = vec_to_tensors(d, params)
+         else: var.update = params.zeros_like()
+
+         return var
+
+
+     @final
+     @torch.no_grad
+     def update(self, var):
+         step = self.global_state.get('step', 0)
+         self.global_state['step'] = step + 1
+
+         if step % self.defaults["update_freq"] == 0:
+
+             hessian_module = self.children['hess_module']
+             hessian_module.update(var)
+             H = hessian_module.get_H(var)
+             self.global_state["H"] = H
+
+             self.trust_region_update(var, H=H)
+
+
+     @final
+     @torch.no_grad
+     def apply(self, var):
+         H = self.global_state.get('H', None)
+
+         # -------------------------------- inner step -------------------------------- #
+         update = var.get_update()
+         if 'inner' in self.children:
+             update = apply_transform(self.children['inner'], update, params=var.params, grads=var.grad, var=var)
+
+         # ----------------------------------- apply ---------------------------------- #
+         return self.trust_region_apply(var=var, tensors=update, H=H)
+
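The accept/shrink/grow logic in ``default_radius`` above is easier to see outside the module plumbing. Below is a minimal, self-contained sketch of the same ratio test on a toy quadratic in plain PyTorch; it does not use torchzero, the problem data and helper names are made up for illustration, and the thresholds simply mirror the suggested defaults noted in the comments above.

```python
import torch

# Toy quadratic f(x) = 0.5 * x^T A x - b^T x, with gradient A x - b and Hessian A.
A = torch.tensor([[3.0, 0.5], [0.5, 1.0]])
b = torch.tensor([1.0, -2.0])

def f(x): return 0.5 * x @ A @ x - b @ x

x = torch.zeros(2)
radius = 1.0
eta, nplus, nminus = 0.0, 3.5, 0.25        # acceptance threshold, grow/shrink factors
rho_good, rho_bad = 0.99, 1e-4             # thresholds on the reduction ratio

for _ in range(20):
    g = A @ x - b                          # gradient at the current iterate
    H = A                                  # exact Hessian for this toy problem
    d = torch.linalg.solve(H, g)           # unconstrained step (it gets SUBTRACTED from x)
    if d.norm() > radius:                  # crude projection onto the trust-region boundary
        d = d * (radius / d.norm())

    pred_reduction = g @ d - 0.5 * d @ (H @ d)   # model reduction for the step x -> x - d
    actual_reduction = f(x) - f(x - d)
    rho = (actual_reduction / pred_reduction.clamp(min=1e-12)).item()

    if rho > eta:                          # ratio test: accept the step
        x = x - d
    if rho < rho_bad:                      # poor agreement: shrink around the attempted step
        radius = d.norm().item() * nminus
    elif rho > rho_good and d.norm() >= (1 - 1e-2) * radius:
        radius = max(radius, d.norm().item() * nplus)   # grow only if the step hit the boundary

print(x, torch.linalg.solve(A, b))         # x approaches the unconstrained minimizer A^{-1} b
```

Because the Hessian is exact here, the quadratic model matches the objective and every step is accepted; the shrink branch is kept only to show where ``rho_bad`` and ``nminus`` enter.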
torchzero/modules/variance_reduction/__init__.py
@@ -0,0 +1 @@
+ from .svrg import SVRG
torchzero/modules/variance_reduction/svrg.py
@@ -0,0 +1,208 @@
+ import warnings
+ from functools import partial
+
+ import torch
+
+ from ...core.module import Module
+ from ...utils import tofloat
+
+
+ def _reset_except_self(optimizer, var, self: Module):
+     for m in optimizer.unrolled_modules:
+         if m is not self:
+             m.reset()
+
+ class SVRG(Module):
+     """Stochastic variance reduced gradient method (SVRG).
+
+     To use, put SVRG as the first module; it can be combined with any other modules.
+     To reduce variance of a gradient estimator, put the gradient estimator before SVRG.
+
+     First it uses the first ``accum_steps`` batches to compute the full gradient at the initial
+     parameters using gradient accumulation; the model is not updated during this.
+
+     Then it performs ``svrg_steps`` SVRG steps, each of which requires two forward and backward passes.
+
+     After ``svrg_steps``, it goes back to the full gradient computation step.
+
+     As an alternative to gradient accumulation you can pass a "full_closure" argument to the ``step`` method,
+     which should compute full gradients, set them to the ``.grad`` attributes of the parameters,
+     and return the full loss.
+
+     Args:
+         svrg_steps (int): number of steps before calculating the full gradient. This can be set to the length of the dataloader.
+         accum_steps (int | None, optional):
+             number of steps to accumulate the gradient for. Not used if "full_closure" is passed to the ``step`` method. If None, uses the value of ``svrg_steps``. Defaults to None.
+         reset_before_accum (bool, optional):
+             whether to reset all other modules when re-calculating the full gradient. Defaults to True.
+         svrg_loss (bool, optional):
+             whether to replace the loss with the SVRG loss (calculated by the same formula as the SVRG gradient). Defaults to True.
+         alpha (float, optional):
+             multiplier to the ``g_full(x_0) - g_batch(x_0)`` term, can be annealed linearly from 1 to 0 as suggested in https://arxiv.org/pdf/2311.05589#page=6
+
+     ## Examples:
+     SVRG-LBFGS
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SVRG(len(dataloader)),
+         tz.m.LBFGS(),
+         tz.m.Backtracking(),
+     )
+     ```
+
+     For extra variance reduction one can use Online versions of algorithms, although it won't always help.
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SVRG(len(dataloader)),
+         tz.m.Online(tz.m.LBFGS()),
+         tz.m.Backtracking(),
+     )
+     ```
+
+     Variance reduction can also be applied to gradient estimators.
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.SPSA(),
+         tz.m.SVRG(100),
+         tz.m.LR(1e-2),
+     )
+     ```
+
+     ## Notes
+
+     The SVRG gradient is computed as ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))``, where:
+
+     - ``x`` is the current parameters
+     - ``x_0`` is the initial parameters, where the full gradient was computed
+     - ``g_b`` refers to the mini-batch gradient at ``x`` or ``x_0``
+     - ``g_f`` refers to the full gradient at ``x_0``.
+
+     The SVRG loss is computed using the same formula.
+     """
+     def __init__(self, svrg_steps: int, accum_steps: int | None = None, reset_before_accum: bool = True, svrg_loss: bool = True, alpha: float = 1):
+         defaults = dict(svrg_steps=svrg_steps, accum_steps=accum_steps, reset_before_accum=reset_before_accum, svrg_loss=svrg_loss, alpha=alpha)
+         super().__init__(defaults)
+
+     @torch.no_grad
+     def step(self, var):
+         params = var.params
+         closure = var.closure
+         assert closure is not None
+
+         if "full_grad" not in self.global_state:
+
+             # -------------------------- calculate full gradient ------------------------- #
+             if "full_closure" in var.storage:
+                 full_closure = var.storage['full_closure']
+                 with torch.enable_grad():
+                     full_loss = full_closure()
+                 if all(p.grad is None for p in params):
+                     warnings.warn("all gradients are None after evaluating full_closure.")
+
+                 full_grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                 self.global_state["full_loss"] = full_loss
+                 self.global_state["full_grad"] = full_grad
+                 self.global_state['x_0'] = [p.clone() for p in params]
+
+                 # current batch will be used for the svrg update
+
+             else:
+                 # accumulate gradients over n steps
+                 accum_steps = self.defaults['accum_steps']
+                 if accum_steps is None: accum_steps = self.defaults['svrg_steps']
+
+                 current_accum_step = self.global_state.get('current_accum_step', 0) + 1
+                 self.global_state['current_accum_step'] = current_accum_step
+
+                 # accumulate grads
+                 accumulator = self.get_state(params, 'accumulator')
+                 grad = var.get_grad()
+                 torch._foreach_add_(accumulator, grad)
+
+                 # accumulate loss
+                 loss_accumulator = self.global_state.get('loss_accumulator', 0)
+                 loss_accumulator += tofloat(var.loss)
+                 self.global_state['loss_accumulator'] = loss_accumulator
+
+                 # on the nth step, use the accumulated gradient
+                 if current_accum_step >= accum_steps:
+                     torch._foreach_div_(accumulator, accum_steps)
+                     self.global_state["full_grad"] = accumulator
+                     self.global_state["full_loss"] = loss_accumulator / accum_steps
+
+                     self.global_state['x_0'] = [p.clone() for p in params]
+                     self.clear_state_keys('accumulator')
+                     del self.global_state['current_accum_step']
+
+                 # otherwise skip the update until enough grads are accumulated
+                 else:
+                     var.update = None
+                     var.stop = True
+                     var.skip_update = True
+                     return var
+
+
+         svrg_steps = self.defaults['svrg_steps']
+         current_svrg_step = self.global_state.get('current_svrg_step', 0) + 1
+         self.global_state['current_svrg_step'] = current_svrg_step
+
+         # --------------------------- SVRG gradient closure -------------------------- #
+         x0 = self.global_state['x_0']
+         gf_x0 = self.global_state["full_grad"]
+         ff_x0 = self.global_state['full_loss']
+         use_svrg_loss = self.defaults['svrg_loss']
+         alpha = self.get_settings(params, 'alpha')
+         alpha_0 = alpha[0]
+         if all(a == 1 for a in alpha): alpha = None
+
+         def svrg_closure(backward=True):
+             # g_svrg = g_b(x) - α * (g_b(x_0) - g_f(x_0)), and the same formula for the loss
+             with torch.no_grad():
+                 x = [p.clone() for p in params]
+
+             if backward:
+                 # f and g at x
+                 with torch.enable_grad(): fb_x = closure()
+                 gb_x = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+                 # f and g at x_0
+                 torch._foreach_copy_(params, x0)
+                 with torch.enable_grad(): fb_x0 = closure()
+                 gb_x0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+                 torch._foreach_copy_(params, x)
+
+                 # g_svrg = gb_x - alpha * (gb_x0 - gf_x0)
+                 correction = torch._foreach_sub(gb_x0, gf_x0)
+                 if alpha is not None: torch._foreach_mul_(correction, alpha)
+                 g_svrg = torch._foreach_sub(gb_x, correction)
+
+                 f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                 for p, g in zip(params, g_svrg):
+                     p.grad = g
+
+                 if use_svrg_loss: return f_svrg
+                 return fb_x
+
+             # no backward
+             if use_svrg_loss:
+                 fb_x = closure(False)
+                 torch._foreach_copy_(params, x0)
+                 fb_x0 = closure(False)
+                 torch._foreach_copy_(params, x)
+                 f_svrg = fb_x - alpha_0 * (fb_x0 - ff_x0)
+                 return f_svrg
+
+             return closure(False)
+
+         var.closure = svrg_closure
+
+         # --- after svrg_steps steps, reset so that a new full gradient is calculated on the next step --- #
+         if current_svrg_step >= svrg_steps:
+             del self.global_state['current_svrg_step']
+             del self.global_state['full_grad']
+             del self.global_state['full_loss']
+             del self.global_state['x_0']
+             if self.defaults['reset_before_accum']:
+                 var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+         return var
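To make the Notes formula above concrete, here is a small self-contained sketch (plain PyTorch, not the torchzero API) that compares a plain mini-batch gradient with the variance-reduced estimator ``g_b(x) - alpha * (g_b(x_0) - g_f(x_0))`` on a toy least-squares problem; the data and helper names are illustrative only.

```python
import torch

torch.manual_seed(0)

# Toy least-squares problem: f(x) = mean_i 0.5 * (a_i @ x - y_i)^2
A = torch.randn(512, 4)
y = torch.randn(512)

def batch_grad(x, idx):                # mini-batch gradient g_b at x
    r = A[idx] @ x - y[idx]
    return A[idx].T @ r / len(idx)

def full_grad(x):                      # full gradient g_f
    return A.T @ (A @ x - y) / len(y)

x0 = torch.zeros(4)                    # reference point where the full gradient is stored
gf_x0 = full_grad(x0)
x = x0 + 0.01 * torch.randn(4)         # current iterate, close to x0

alpha = 1.0
plain_err, svrg_err = [], []
for _ in range(1000):
    idx = torch.randint(0, 512, (8,))  # small mini-batch
    gb_x, gb_x0 = batch_grad(x, idx), batch_grad(x0, idx)
    g_svrg = gb_x - alpha * (gb_x0 - gf_x0)        # same formula as in svrg_closure above
    plain_err.append((gb_x - full_grad(x)).norm())
    svrg_err.append((g_svrg - full_grad(x)).norm())

# near x_0 the variance-reduced estimator deviates far less from the true gradient
print(torch.stack(plain_err).mean(), torch.stack(svrg_err).mean())
```

Near ``x_0`` the two mini-batch terms nearly cancel, so the estimator stays close to the stored full gradient; this is the variance reduction that ``svrg_closure`` exploits between full-gradient recomputations.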