torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/second_order/newton_cg.py
@@ -1,11 +1,14 @@
- from typing import Literal, overload
+ import warnings
+ import math
+ from typing import Literal, cast
+ from operator import itemgetter
  import torch

- from ...utils import TensorList, as_tensorlist, NumberList
+ from ...core import Chainable, Module, apply_transform
+ from ...utils import TensorList, as_tensorlist, tofloat
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
-
- from ...core import Chainable, apply_transform, Module
- from ...utils.linalg.solve import cg, steihaug_toint_cg, minres
+ from ...utils.linalg.solve import cg, minres, find_within_trust_radius
+ from ..trust_region.trust_region import default_radius


  class NewtonCG(Module):
  """Newton's method with a matrix-free conjugate gradient or minimial-residual solver.
@@ -88,20 +91,25 @@ class NewtonCG(Module):
  def __init__(
  self,
  maxiter: int | None = None,
- tol: float = 1e-4,
+ tol: float = 1e-8,
  reg: float = 1e-8,
  hvp_method: Literal["forward", "central", "autograd"] = "autograd",
  solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
  h: float = 1e-3,
+ miniter:int = 1,
  warm_start=False,
  inner: Chainable | None = None,
  ):
- defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, solver=solver, h=h, warm_start=warm_start)
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner']
  super().__init__(defaults,)

  if inner is not None:
  self.set_child('inner', inner)

+ self._num_hvps = 0
+ self._num_hvps_last_step = 0
+
  @torch.no_grad
  def step(self, var):
  params = TensorList(var.params)
@@ -117,11 +125,13 @@ class NewtonCG(Module):
  h = settings['h']
  warm_start = settings['warm_start']

+ self._num_hvps_last_step = 0
  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
  grad = var.get_grad(create_graph=True)

  def H_mm(x):
+ self._num_hvps_last_step += 1
  with torch.enable_grad():
  return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -132,10 +142,12 @@ class NewtonCG(Module):

  if hvp_method == 'forward':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

  elif hvp_method == 'central':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

  else:
@@ -153,26 +165,28 @@ class NewtonCG(Module):
  if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway

  if solver == 'cg':
- x = cg(A_mm=H_mm, b=b, x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
+ d, _ = cg(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, miniter=self.defaults["miniter"],reg=reg)

  elif solver == 'minres':
- x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
+ d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)

  elif solver == 'minres_npc':
- x = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
+ d = minres(A_mm=H_mm, b=b, x0=x0, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)

  else:
  raise ValueError(f"Unknown solver {solver}")

  if warm_start:
  assert x0 is not None
- x0.copy_(x)
+ x0.copy_(d)
+
+ var.update = d

- var.update = x
+ self._num_hvps += self._num_hvps_last_step
  return var


- class TruncatedNewtonCG(Module):
+ class NewtonCGSteihaug(Module):
  """Trust region Newton's method with a matrix-free Steihaug-Toint conjugate gradient or MINRES solver.

  This optimizer implements Newton's method using a matrix-free conjugate
@@ -245,49 +259,61 @@ class TruncatedNewtonCG(Module):
  def __init__(
  self,
  maxiter: int | None = None,
- eta: float= 1e-6,
- nplus: float = 2,
+ eta: float= 0.0,
+ nplus: float = 3.5,
  nminus: float = 0.25,
+ rho_good: float = 0.99,
+ rho_bad: float = 1e-4,
  init: float = 1,
- tol: float = 1e-4,
+ tol: float = 1e-8,
  reg: float = 1e-8,
- hvp_method: Literal["forward", "central", "autograd"] = "autograd",
- solver: Literal['cg', 'minres', 'minres_npc'] = 'cg',
+ hvp_method: Literal["forward", "central"] = "forward",
+ solver: Literal['cg', "minres"] = 'cg',
  h: float = 1e-3,
- max_attempts: int = 10,
+ max_attempts: int = 100,
+ max_history: int = 100,
+ boundary_tol: float = 1e-1,
+ miniter: int = 1,
+ rms_beta: float | None = None,
+ adapt_tol: bool = True,
+ npc_terminate: bool = False,
  inner: Chainable | None = None,
  ):
- defaults = dict(tol=tol, maxiter=maxiter, reg=reg, hvp_method=hvp_method, h=h, eta=eta, nplus=nplus, nminus=nminus, init=init, max_attempts=max_attempts, solver=solver)
+ defaults = locals().copy()
+ del defaults['self'], defaults['inner']
  super().__init__(defaults,)

  if inner is not None:
  self.set_child('inner', inner)

+ self._num_hvps = 0
+ self._num_hvps_last_step = 0
+
  @torch.no_grad
  def step(self, var):
  params = TensorList(var.params)
  closure = var.closure
  if closure is None: raise RuntimeError('NewtonCG requires closure')

- settings = self.settings[params[0]]
- tol = settings['tol']
- reg = settings['reg']
- maxiter = settings['maxiter']
- hvp_method = settings['hvp_method']
- h = settings['h']
- max_attempts = settings['max_attempts']
- solver = settings['solver'].lower().strip()
+ tol = self.defaults['tol'] * self.global_state.get('tol_mul', 1)
+ solver = self.defaults['solver'].lower().strip()

- eta = settings['eta']
- nplus = settings['nplus']
- nminus = settings['nminus']
- init = settings['init']
+ (reg, maxiter, hvp_method, h, max_attempts, boundary_tol,
+ eta, nplus, nminus, rho_good, rho_bad, init, npc_terminate,
+ miniter, max_history, adapt_tol) = itemgetter(
+ "reg", "maxiter", "hvp_method", "h", "max_attempts", "boundary_tol",
+ "eta", "nplus", "nminus", "rho_good", "rho_bad", "init", "npc_terminate",
+ "miniter", "max_history", "adapt_tol",
+ )(self.defaults)
+
+ self._num_hvps_last_step = 0

  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
  grad = var.get_grad(create_graph=True)

  def H_mm(x):
+ self._num_hvps_last_step += 1
  with torch.enable_grad():
  return TensorList(hvp(params, grad, x, retain_graph=True))

@@ -298,77 +324,112 @@ class TruncatedNewtonCG(Module):

  if hvp_method == 'forward':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_forward(closure, params, x, h=h, g_0=grad, normalize=True)[1])

  elif hvp_method == 'central':
  def H_mm(x):
+ self._num_hvps_last_step += 1
  return TensorList(hvp_fd_central(closure, params, x, h=h, normalize=True)[1])

  else:
  raise ValueError(hvp_method)


- # -------------------------------- inner step -------------------------------- #
+ # ------------------------- update RMS preconditioner ------------------------ #
  b = var.get_update()
+ P_mm = None
+ rms_beta = self.defaults["rms_beta"]
+ if rms_beta is not None:
+ exp_avg_sq = self.get_state(params, "exp_avg_sq", init=b, cls=TensorList)
+ exp_avg_sq.mul_(rms_beta).addcmul(b, b, value=1-rms_beta)
+ exp_avg_sq_sqrt = exp_avg_sq.sqrt().add_(1e-8)
+ def _P_mm(x):
+ return x / exp_avg_sq_sqrt
+ P_mm = _P_mm
+
+ # -------------------------------- inner step -------------------------------- #
  if 'inner' in self.children:
  b = apply_transform(self.children['inner'], b, params=params, grads=grad, var=var)
  b = as_tensorlist(b)

- # ---------------------------------- run cg ---------------------------------- #
+ # ------------------------------- trust region ------------------------------- #
  success = False
- x = None
+ d = None
+ x0 = [p.clone() for p in params]
+ solution = None
+
  while not success:
  max_attempts -= 1
  if max_attempts < 0: break

- trust_region = self.global_state.get('trust_region', init)
- if trust_region < 1e-8 or trust_region > 1e8:
- trust_region = self.global_state['trust_region'] = init
-
- if solver == 'cg':
- x = steihaug_toint_cg(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg)
-
- elif solver == 'minres':
- x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=False)
-
- elif solver == 'minres_npc':
- x = minres(A_mm=H_mm, b=b, trust_region=trust_region, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=True)
-
- else:
- raise ValueError(f"unknown solver {solver}")
-
- # ------------------------------- trust region ------------------------------- #
- Hx = H_mm(x)
- pred_reduction = b.dot(x) - 0.5 * x.dot(Hx)
-
- params -= x
- loss_star = closure(False)
- params += x
- reduction = var.get_loss(False) - loss_star
-
- rho = reduction / (pred_reduction.clip(min=1e-8))
-
- # failed step
- if rho < 0.25:
- self.global_state['trust_region'] = trust_region * nminus
-
- # very good step
- elif rho > 0.75:
- diff = trust_region - x.abs()
- if (diff.global_min() / trust_region) > 1e-4: # hits boundary
- self.global_state['trust_region'] = trust_region * nplus
-
- # if the ratio is high enough then accept the proposed step
- if rho > eta:
- success = True
+ trust_radius = self.global_state.get('trust_radius', init)
+
+ # -------------- make sure trust radius isn't too small or large ------------- #
+ finfo = torch.finfo(x0[0].dtype)
+ if trust_radius < finfo.tiny * 2:
+ trust_radius = self.global_state['trust_radius'] = init
+ if adapt_tol:
+ self.global_state["tol_mul"] = self.global_state.get("tol_mul", 1) * 0.1
+
+ elif trust_radius > finfo.max / 2:
+ trust_radius = self.global_state['trust_radius'] = init
+
+ # ----------------------------------- solve ---------------------------------- #
+ d = None
+ if solution is not None and solution.history is not None:
+ d = find_within_trust_radius(solution.history, trust_radius)
+
+ if d is None:
+ if solver == 'cg':
+ d, solution = cg(
+ A_mm=H_mm,
+ b=b,
+ tol=tol,
+ maxiter=maxiter,
+ reg=reg,
+ trust_radius=trust_radius,
+ miniter=miniter,
+ npc_terminate=npc_terminate,
+ history_size=max_history,
+ P_mm=P_mm,
+ )
+
+ elif solver == 'minres':
+ d = minres(A_mm=H_mm, b=b, trust_radius=trust_radius, tol=tol, maxiter=maxiter, reg=reg, npc_terminate=npc_terminate)
+
+ else:
+ raise ValueError(f"unknown solver {solver}")
+
+ # ---------------------------- update trust radius --------------------------- #
+ self.global_state["trust_radius"], success = default_radius(
+ params=params,
+ closure=closure,
+ f=tofloat(var.get_loss(False)),
+ g=b,
+ H=H_mm,
+ d=d,
+ trust_radius=trust_radius,
+ eta=eta,
+ nplus=nplus,
+ nminus=nminus,
+ rho_good=rho_good,
+ rho_bad=rho_bad,
+ boundary_tol=boundary_tol,
+
+ init=init, # init isn't used because check_overflow=False
+ state=self.global_state, # not used
+ settings=self.defaults, # not used
+ check_overflow=False, # this is checked manually to adapt tolerance
+ )

- assert x is not None
+ # --------------------------- assign new direction --------------------------- #
+ assert d is not None
  if success:
- var.update = x
+ var.update = d

  else:
  var.update = params.zeros_like()

- return var
-
-
+ self._num_hvps += self._num_hvps_last_step
+ return var
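The final hunk swaps the hand-rolled radius logic (the rho < 0.25 / rho > 0.75 branches in the removed lines) for the shared default_radius helper, driven by eta, nplus, nminus, rho_good, rho_bad and boundary_tol. The underlying logic is the classic trust-region acceptance ratio: compare the actual reduction in the loss against the reduction predicted by the local quadratic model, shrink the radius when the model fits poorly, expand it when the fit is excellent and the step was limited by the radius, and accept the step only if the ratio exceeds eta. The sketch below uses the same parameter names but is an illustration of that test, not the actual default_radius signature.

```python
import torch

def update_trust_radius(f_old, f_new, g, Hd, d, trust_radius,
                        eta=0.0, nplus=3.5, nminus=0.25,
                        rho_good=0.99, rho_bad=1e-4, boundary_tol=1e-1):
    """Classic trust-region ratio test (illustrative; not torchzero's default_radius).

    The quadratic model of the step 'params -= d' predicts a reduction of
    g.d - 0.5 * d.H.d; rho compares the actual reduction against it."""
    pred_reduction = g.dot(d) - 0.5 * d.dot(Hd)
    rho = (f_old - f_new) / pred_reduction.clamp(min=1e-12)

    if rho < rho_bad:
        # the model fit the function poorly: shrink the radius
        trust_radius = trust_radius * nminus
    elif rho > rho_good and d.norm() >= (1 - boundary_tol) * trust_radius:
        # excellent fit and the step was limited by the radius: expand it
        trust_radius = trust_radius * nplus

    accepted = bool(rho > eta)   # keep the step only if it made real progress
    return trust_radius, accepted
```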
torchzero/modules/smoothing/__init__.py
@@ -1,2 +1,2 @@
  from .laplacian import LaplacianSmoothing
- from .gaussian import GaussianHomotopy
+ from .sampling import GradientSampling
torchzero/modules/smoothing/sampling.py (new file)
@@ -0,0 +1,300 @@
+ import math
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable, Sequence
+ from contextlib import nullcontext
+ from functools import partial
+ from typing import Literal, cast
+
+ import torch
+
+ from ...core import Chainable, Modular, Module, Var
+ from ...core.reformulation import Reformulation
+ from ...utils import Distributions, NumberList, TensorList
+ from ..termination import TerminationCriteriaBase, make_termination_criteria
+
+
+ def _reset_except_self(optimizer: Modular, var: Var, self: Module):
+ for m in optimizer.unrolled_modules:
+ if m is not self:
+ m.reset()
+
+
+ class GradientSampling(Reformulation):
+ """Samples and aggregates gradients and values at perturbed points.
+
+ This module can be used for gaussian homotopy and gradient sampling methods.
+
+ Args:
+ modules (Chainable | None, optional):
+ modules that will be optimizing the modified objective.
+ if None, returns gradient of the modified objective as the update. Defaults to None.
+ sigma (float, optional): initial magnitude of the perturbations. Defaults to 1.
+ n (int, optional): number of perturbations per step. Defaults to 100.
+ aggregate (str, optional):
+ how to aggregate values and gradients
+ - "mean" - uses mean of the gradients, as in gaussian homotopy.
+ - "max" - uses element-wise maximum of the gradients.
+ - "min" - uses element-wise minimum of the gradients.
+ - "min-norm" - picks gradient with the lowest norm.
+
+ Defaults to 'mean'.
+ distribution (Distributions, optional): distribution for random perturbations. Defaults to 'gaussian'.
+ include_x0 (bool, optional): whether to include gradient at un-perturbed point. Defaults to True.
+ fixed (bool, optional):
+ if True, perturbations do not get replaced by new random perturbations until termination criteria is satisfied. Defaults to True.
+ pre_generate (bool, optional):
+ if True, perturbations are pre-generated before each step.
+ This requires more memory to store all of them,
+ but ensures they do not change when closure is evaluated multiple times.
+ Defaults to True.
+ termination (TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None, optional):
+ a termination criteria module, sigma will be multiplied by ``decay`` when termination criteria is satisfied,
+ and new perturbations will be generated if ``fixed``. Defaults to None.
+ decay (float, optional): sigma multiplier on termination criteria. Defaults to 2/3.
+ reset_on_termination (bool, optional): whether to reset states of all other modules on termination. Defaults to True.
+ sigma_strategy (str | None, optional):
+ strategy for adapting sigma. If condition is satisfied, sigma is multiplied by ``sigma_nplus``,
+ otherwise it is multiplied by ``sigma_nminus``.
+ - "grad-norm" - at least ``sigma_target`` gradients should have lower norm than at un-perturbed point.
+ - "value" - at least ``sigma_target`` values (losses) should be lower than at un-perturbed point.
+ - None - doesn't use adaptive sigma.
+
+ This introduces a side-effect to the closure, so it should be left at None of you use
+ trust region or line search to optimize the modified objective.
+ Defaults to None.
+ sigma_target (int, optional):
+ number of elements to satisfy the condition in ``sigma_strategy``. Defaults to 1.
+ sigma_nplus (float, optional): sigma multiplier when ``sigma_strategy`` condition is satisfied. Defaults to 4/3.
+ sigma_nminus (float, optional): sigma multiplier when ``sigma_strategy`` condition is not satisfied. Defaults to 2/3.
+ seed (int | None, optional): seed. Defaults to None.
+ """
+ def __init__(
+ self,
+ modules: Chainable | None = None,
+ sigma: float = 1.,
+ n:int = 100,
+ aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = 'mean',
+ distribution: Distributions = 'gaussian',
+ include_x0: bool = True,
+
+ fixed: bool=True,
+ pre_generate: bool = True,
+ termination: TerminationCriteriaBase | Sequence[TerminationCriteriaBase] | None = None,
+ decay: float = 2/3,
+ reset_on_termination: bool = True,
+
+ sigma_strategy: Literal['grad-norm', 'value'] | None = None,
+ sigma_target: int | float = 0.2,
+ sigma_nplus: float = 4/3,
+ sigma_nminus: float = 2/3,
+
+ seed: int | None = None,
+ ):
+
+ defaults = dict(sigma=sigma, n=n, aggregate=aggregate, distribution=distribution, seed=seed, include_x0=include_x0, fixed=fixed, decay=decay, reset_on_termination=reset_on_termination, sigma_strategy=sigma_strategy, sigma_target=sigma_target, sigma_nplus=sigma_nplus, sigma_nminus=sigma_nminus, pre_generate=pre_generate)
+ super().__init__(defaults, modules)
+
+ if termination is not None:
+ self.set_child('termination', make_termination_criteria(extra=termination))
+
+ @torch.no_grad
+ def pre_step(self, var):
+ params = TensorList(var.params)
+
+ fixed = self.defaults['fixed']
+
+ # check termination criteria
+ if 'termination' in self.children:
+ termination = cast(TerminationCriteriaBase, self.children['termination'])
+ if termination.should_terminate(var):
+
+ # decay sigmas
+ states = [self.state[p] for p in params]
+ settings = [self.settings[p] for p in params]
+
+ for state, setting in zip(states, settings):
+ if 'sigma' not in state: state['sigma'] = setting['sigma']
+ state['sigma'] *= setting['decay']
+
+ # reset on sigmas decay
+ if self.defaults['reset_on_termination']:
+ var.post_step_hooks.append(partial(_reset_except_self, self=self))
+
+ # clear perturbations
+ self.global_state.pop('perts', None)
+
+ # pre-generate perturbations if not already pre-generated or not fixed
+ if self.defaults['pre_generate'] and (('perts' not in self.global_state) or (not fixed)):
+ states = [self.state[p] for p in params]
+ settings = [self.settings[p] for p in params]
+
+ n = self.defaults['n'] - self.defaults['include_x0']
+ generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+ perts = [params.sample_like(self.defaults['distribution'], generator=generator) for _ in range(n)]
+
+ self.global_state['perts'] = perts
+
+ @torch.no_grad
+ def closure(self, backward, closure, params, var):
+ params = TensorList(params)
+ loss_agg = None
+ grad_agg = None
+
+ states = [self.state[p] for p in params]
+ settings = [self.settings[p] for p in params]
+ sigma_inits = [s['sigma'] for s in settings]
+ sigmas = [s.setdefault('sigma', si) for s, si in zip(states, sigma_inits)]
+
+ include_x0 = self.defaults['include_x0']
+ pre_generate = self.defaults['pre_generate']
+ aggregate: Literal['mean', 'max', 'min', 'min-norm', 'min-value'] = self.defaults['aggregate']
+ sigma_strategy: Literal['grad-norm', 'value'] | None = self.defaults['sigma_strategy']
+ distribution = self.defaults['distribution']
+ generator = self.get_generator(params[0].device, self.defaults['seed'])
+
+
+ n_finite = 0
+ n_good = 0
+ f_0 = None; g_0 = None
+
+ # evaluate at x_0
+ if include_x0:
+ f_0 = cast(torch.Tensor, var.get_loss(backward=backward))
+
+ isfinite = math.isfinite(f_0)
+ if isfinite:
+ n_finite += 1
+ loss_agg = f_0
+
+ if backward:
+ g_0 = var.get_grad()
+ if isfinite: grad_agg = g_0
+
+ # evaluate at x_0 + p for each perturbation
+ if pre_generate:
+ perts = self.global_state['perts']
+ else:
+ perts = [None] * (self.defaults['n'] - include_x0)
+
+ x_0 = [p.clone() for p in params]
+
+ for pert in perts:
+ loss = None; grad = None
+
+ # generate if not pre-generated
+ if pert is None:
+ pert = params.sample_like(distribution, generator=generator)
+
+ # add perturbation and evaluate
+ pert = pert * sigmas
+ torch._foreach_add_(params, pert)
+
+ with torch.enable_grad() if backward else nullcontext():
+ loss = closure(backward)
+
+ if math.isfinite(loss):
+ n_finite += 1
+
+ # add loss
+ if loss_agg is None:
+ loss_agg = loss
+ else:
+ if aggregate == 'mean':
+ loss_agg += loss
+
+ elif (aggregate=='min') or (aggregate=='min-value') or (aggregate=='min-norm' and not backward):
+ loss_agg = loss_agg.clamp(max=loss)
+
+ elif aggregate == 'max':
+ loss_agg = loss_agg.clamp(min=loss)
+
+ # add grad
+ if backward:
+ grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+ if grad_agg is None:
+ grad_agg = grad
+ else:
+ if aggregate == 'mean':
+ torch._foreach_add_(grad_agg, grad)
+
+ elif aggregate == 'min':
+ grad_agg_abs = torch._foreach_abs(grad_agg)
+ torch._foreach_minimum_(grad_agg_abs, torch._foreach_abs(grad))
+ grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+ elif aggregate == 'max':
+ grad_agg_abs = torch._foreach_abs(grad_agg)
+ torch._foreach_maximum_(grad_agg_abs, torch._foreach_abs(grad))
+ grad_agg = [g_abs.copysign(g) for g_abs, g in zip(grad_agg_abs, grad_agg)]
+
+ elif aggregate == 'min-norm':
+ if TensorList(grad).global_vector_norm() < TensorList(grad_agg).global_vector_norm():
+ grad_agg = grad
+ loss_agg = loss
+
+ elif aggregate == 'min-value':
+ if loss < loss_agg:
+ grad_agg = grad
+ loss_agg = loss
+
+ # undo perturbation
+ torch._foreach_copy_(params, x_0)
+
+ # adaptive sigma
+ # by value
+ if sigma_strategy == 'value':
+ if f_0 is None:
+ with torch.enable_grad() if backward else nullcontext():
+ f_0 = closure(False)
+
+ if loss < f_0:
+ n_good += 1
+
+ # by gradient norm
+ elif sigma_strategy == 'grad-norm' and backward and math.isfinite(loss):
+ assert grad is not None
+ if g_0 is None:
+ with torch.enable_grad() if backward else nullcontext():
+ closure()
+ g_0 = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
+
+ if TensorList(grad).global_vector_norm() < TensorList(g_0).global_vector_norm():
+ n_good += 1
+
+ # update sigma if strategy is enabled
+ if sigma_strategy is not None:
+
+ sigma_target = self.defaults['sigma_target']
+ if isinstance(sigma_target, float):
+ sigma_target = int(max(1, n_finite * sigma_target))
+
+ if n_good >= sigma_target:
+ key = 'sigma_nplus'
+ else:
+ key = 'sigma_nminus'
+
+ for p in params:
+ self.state[p]['sigma'] *= self.settings[p][key]
+
+ # if no finite losses, just return inf
+ if n_finite == 0:
+ assert loss_agg is None and grad_agg is None
+ loss = torch.tensor(torch.inf, dtype=params[0].dtype, device=params[0].device)
+ grad = [torch.full_like(p, torch.inf) for p in params]
+ return loss, grad
+
+ assert loss_agg is not None
+
+ # no post processing needed when aggregate is 'max', 'min', 'min-norm', 'min-value'
+ if aggregate != 'mean':
+ return loss_agg, grad_agg
+
+ # on mean divide by number of evals
+ loss_agg /= n_finite
+
+ if backward:
+ assert grad_agg is not None
+ torch._foreach_div_(grad_agg, n_finite)
+
+ return loss_agg, grad_agg
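GradientSampling replaces the old GaussianHomotopy module: it perturbs the parameters n times per step, evaluates loss and gradient at each perturbed point, aggregates them (mean for gaussian homotopy, or the min/max/min-norm/min-value variants), and restores the original parameters. A stripped-down sketch of the 'mean' case follows, written against plain tensors rather than torchzero's TensorList, Reformulation and closure machinery; sampled_gradient is an illustrative name and the closure is assumed to compute the loss and populate .grad itself.

```python
import torch

def sampled_gradient(closure, params, sigma=1.0, n=16, include_x0=True, generator=None):
    """Average loss and gradient over Gaussian perturbations of the parameters
    (the 'mean' aggregation, i.e. a smoothed, gaussian-homotopy-style objective)."""
    x0 = [p.detach().clone() for p in params]
    loss_sum = 0.0
    grad_sum = [torch.zeros_like(p) for p in params]
    n_evals = 0

    points = [None] * n
    if include_x0:
        points[0] = [torch.zeros_like(p) for p in params]   # un-perturbed point

    for pert in points:
        if pert is None:
            pert = [sigma * torch.randn(p.shape, generator=generator,
                                        device=p.device, dtype=p.dtype) for p in params]
        with torch.no_grad():
            for p, x0_i, e in zip(params, x0, pert):
                p.copy_(x0_i + e)     # move to the perturbed point
                p.grad = None
        loss = closure()              # evaluates the loss and fills p.grad via backward
        loss_sum += float(loss)
        for gs, p in zip(grad_sum, params):
            gs += p.grad if p.grad is not None else torch.zeros_like(p)
        n_evals += 1

    with torch.no_grad():             # restore the original parameters
        for p, x0_i in zip(params, x0):
            p.copy_(x0_i)
    return loss_sum / n_evals, [g / n_evals for g in grad_sum]
```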
torchzero/modules/step_size/__init__.py
@@ -1,2 +1,2 @@
  from .lr import LR, StepSize, Warmup, WarmupNormClip, RandomStepSize
- from .adaptive import PolyakStepSize, BarzilaiBorwein
+ from .adaptive import PolyakStepSize, BarzilaiBorwein, BBStab, AdGD
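The step-size exports gain BBStab and AdGD alongside BarzilaiBorwein. For orientation, the two classic Barzilai-Borwein step sizes are computed from consecutive parameter and gradient differences; the sketch below uses the standard textbook formulas and is not necessarily the exact variant these modules implement.

```python
import torch

def bb_step_sizes(x_prev, x_curr, g_prev, g_curr, eps=1e-12):
    """Barzilai-Borwein 'long' and 'short' step sizes from s = x_k - x_{k-1}
    and y = g_k - g_{k-1}: bb1 = s.s / s.y and bb2 = s.y / y.y."""
    s = x_curr - x_prev
    y = g_curr - g_prev
    sy = s.dot(y)
    bb1 = s.dot(s) / sy.clamp(min=eps)     # long BB step
    bb2 = sy / y.dot(y).clamp(min=eps)     # short BB step
    return bb1, bb2
```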