torchzero-0.3.9-py3-none-any.whl → torchzero-0.3.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/quasi_newton.py
@@ -1,11 +1,13 @@
  """Use BFGS or maybe SR1."""
- from typing import Any, Literal
  from abc import ABC, abstractmethod
  from collections.abc import Mapping
+ from typing import Any, Literal
+
  import torch

- from ...core import Chainable, Module, Preconditioner, TensorwisePreconditioner
- from ...utils import TensorList, set_storage_
+ from ...core import Chainable, Module, TensorwiseTransform, Transform
+ from ...utils import TensorList, set_storage_, unpack_states
+

  def _safe_dict_update_(d1_:dict, d2:dict):
  inter = set(d1_.keys()).intersection(d2.keys())
@@ -17,14 +19,14 @@ def _maybe_lerp_(state, key, value: torch.Tensor, beta: float | None):
  elif state[key].shape != value.shape: state[key] = value
  else: state[key].lerp_(value, 1-beta)

- class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
+ class HessianUpdateStrategy(TensorwiseTransform, ABC):
  def __init__(
  self,
  defaults: dict | None = None,
  init_scale: float | Literal["auto"] = "auto",
  tol: float = 1e-10,
  tol_reset: bool = True,
- reset_interval: int | None = None,
+ reset_interval: int | None | Literal['auto'] = None,
  beta: float | None = None,
  update_freq: int = 1,
  scale_first: bool = True,
@@ -44,7 +46,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  if ys != 0 and yy != 0: return yy/ys
  return 1

- def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor,inverse:bool, init_scale: Any):
+ def _reset_M_(self, M: torch.Tensor, s:torch.Tensor,y:torch.Tensor, inverse:bool, init_scale: Any, state:dict[str,Any]):
  set_storage_(M, torch.eye(M.size(-1), device=M.device, dtype=M.dtype))
  if init_scale == 'auto': init_scale = self._get_init_scale(s,y)
  if init_scale >= 1:
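Note (not part of the diff): the "auto" initial scale returned by _get_init_scale above is the standard quasi-Newton scaling heuristic

    \sigma_0 = \frac{y^\top y}{y^\top s},

so the approximation is reset to a scaled identity whose magnitude matches the curvature observed along the most recent step s (gradient change y); this is the usual B_0 scaling discussed in Nocedal & Wright, Numerical Optimization.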
@@ -62,7 +64,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  raise NotImplementedError

  @torch.no_grad
- def update_tensor(self, tensor, param, grad, state, settings):
+ def update_tensor(self, tensor, param, grad, loss, state, settings):
  p = param.view(-1); g = tensor.view(-1)
  inverse = settings['inverse']
  M_key = 'H' if inverse else 'B'
@@ -73,6 +75,7 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  tol = settings['tol']
  tol_reset = settings['tol_reset']
  reset_interval = settings['reset_interval']
+ if reset_interval == 'auto': reset_interval = tensor.numel() + 1

  if M is None:
  M = torch.eye(p.size(0), device=p.device, dtype=p.dtype)
@@ -81,10 +84,12 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  else: M *= init_scale

  state[M_key] = M
+ state['f_prev'] = loss
  state['p_prev'] = p.clone()
  state['g_prev'] = g.clone()
  return

+ state['f'] = loss
  p_prev = state['p_prev']
  g_prev = state['g_prev']
  s: torch.Tensor = p - p_prev
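Note (not part of the diff): update_tensor forms the step difference s = p - p_prev and, just below, the gradient difference y = g - g_prev; the classical quasi-Newton updates in this file are built around the secant condition

    B_{k+1} s_k = y_k \qquad\text{or, for the inverse approximation (the inverse flag),}\qquad H_{k+1} y_k = s_k,

    s_k = p_k - p_{k-1}, \qquad y_k = g_k - g_{k-1}.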
@@ -93,13 +98,13 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  state['g_prev'].copy_(g)

  if reset_interval is not None and step != 0 and step % reset_interval == 0:
- self._reset_M_(M, s, y, inverse, init_scale)
+ self._reset_M_(M, s, y, inverse, init_scale, state)
  return

  # tolerance on gradient difference to avoid exploding after converging
- elif y.abs().max() <= tol:
+ if y.abs().max() <= tol:
  # reset history
- if tol_reset: self._reset_M_(M, s, y, inverse, init_scale)
+ if tol_reset: self._reset_M_(M, s, y, inverse, init_scale, state)
  return

  if step == 1 and init_scale == 'auto':
@@ -117,8 +122,10 @@ class HessianUpdateStrategy(TensorwisePreconditioner, ABC):
  B_new = self.update_B(B=M, s=s, y=y, p=p, g=g, p_prev=p_prev, g_prev=g_prev, state=state, settings=settings)
  _maybe_lerp_(state, 'B', B_new, beta)

+ state['f_prev'] = loss
+
  @torch.no_grad
- def apply_tensor(self, tensor, param, grad, state, settings):
+ def apply_tensor(self, tensor, param, grad, loss, state, settings):
  step = state.get('step', 0)

  if settings['scale_second'] and step == 2:
@@ -198,19 +205,15 @@ class SR1(HUpdateStrategy):
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
  return sr1_H_(H=H, s=s, y=y, tol=settings['tol'])

- # BFGS has defaults - init_scale = "auto" and scale_second = False
- # SR1 has defaults - init_scale = 1 and scale_second = True
- # basically some methods work better with first and some with second.
- # I inherit from BFGS or SR1 to avoid writing all those arguments again
  # ------------------------------------ DFP ----------------------------------- #
  def dfp_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  sy = torch.dot(s, y)
  if sy.abs() <= tol: return H
  term1 = torch.outer(s, s).div_(sy)
- denom = torch.dot(y, H @ y) #
- if denom.abs() <= tol: return H
+ yHy = torch.dot(y, H @ y) #
+ if yHy.abs() <= tol: return H
  num = H @ torch.outer(y, y) @ H
- term2 = num.div_(denom)
+ term2 = num.div_(yHy)
  H += term1.sub_(term2)
  return H

@@ -225,34 +228,35 @@ class DFP(HUpdateStrategy):

  def broyden_good_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  c = H.T @ s
- denom = c.dot(y)
- if denom.abs() <= tol: return H
+ cy = c.dot(y)
+ if cy.abs() <= tol: return H
  num = (H@y).sub_(s).outer(c)
- H -= num/denom
+ H -= num/cy
  return H

  def broyden_bad_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  c = y
- denom = c.dot(y)
- if denom.abs() <= tol: return H
+ cy = c.dot(y)
+ if cy.abs() <= tol: return H
  num = (H@y).sub_(s).outer(c)
- H -= num/denom
+ H -= num/cy
  return H

  def greenstadt1_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g_prev: torch.Tensor, tol: float):
  c = g_prev
- denom = c.dot(y)
- if denom.abs() <= tol: return H
+ cy = c.dot(y)
+ if cy.abs() <= tol: return H
  num = (H@y).sub_(s).outer(c)
- H -= num/denom
+ H -= num/cy
  return H

  def greenstadt2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
- c = torch.linalg.multi_dot([H,H,y]) # pylint:disable=not-callable
- denom = c.dot(y)
- if denom.abs() <= tol: return H
- num = (H@y).sub_(s).outer(c)
- H -= num/denom
+ Hy = H @ y
+ c = H @ Hy # pylint:disable=not-callable
+ cy = c.dot(y)
+ if cy.abs() <= tol: return H
+ num = Hy.sub_(s).outer(c)
+ H -= num/cy
  return H

  class BroydenGood(HUpdateStrategy):
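Note (not part of the diff): written out, dfp_H_ above applies the DFP update of the inverse Hessian approximation

    H_{k+1} = H_k + \frac{s s^\top}{s^\top y} - \frac{H_k y y^\top H_k}{y^\top H_k y},

while broyden_good_H_, broyden_bad_H_, greenstadt1_H_ and greenstadt2_H_ are all instances of the same rank-one correction

    H_{k+1} = H_k - \frac{(H_k y - s)\, c^\top}{c^\top y},

with c = H_k^\top s (Broyden "good"), c = y (Broyden "bad"), c = g_{prev} (Greenstadt's first method) and c = H_k H_k y (Greenstadt's second method); each denominator is guarded by tol exactly as in the code.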
@@ -273,11 +277,7 @@ class Greenstadt2(HUpdateStrategy):


  def column_updating_H_(H:torch.Tensor, s:torch.Tensor, y:torch.Tensor, tol:float):
- n = H.shape[0]
-
  j = y.abs().argmax()
- u = torch.zeros(n, device=H.device, dtype=H.dtype)
- u[j] = 1.0

  denom = y[j]
  if denom.abs() < tol: return H
@@ -297,15 +297,15 @@ def thomas_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor,
  s_norm = torch.linalg.vector_norm(s) # pylint:disable=not-callable
  I = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
  d = (R + I * (s_norm/2)) @ s
- denom = d.dot(s)
- if denom.abs() <= tol: return H, R
- R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(denom)))
+ ds = d.dot(s)
+ if ds.abs() <= tol: return H, R
+ R = (1 + s_norm) * ((I*s_norm).add_(R).sub_(d.outer(d).div_(ds)))

  c = H.T @ d
- denom = c.dot(y)
- if denom.abs() <= tol: return H, R
+ cy = c.dot(y)
+ if cy.abs() <= tol: return H, R
  num = (H@y).sub_(s).outer(c)
- H -= num/denom
+ H -= num/cy
  return H, R

  class ThomasOptimalMethod(HUpdateStrategy):
@@ -315,6 +315,11 @@ class ThomasOptimalMethod(HUpdateStrategy):
  H, state['R'] = thomas_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
  return H

+ def _reset_M_(self, M, s, y,inverse, init_scale, state):
+ super()._reset_M_(M, s, y, inverse, init_scale, state)
+ for st in self.state.values():
+ st.pop("R", None)
+
  # ------------------------ powell's symmetric broyden ------------------------ #
  def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
  y_Bs = y - B@s
@@ -326,6 +331,7 @@ def psb_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
  B += term1.sub_(term2)
  return B

+ # I couldn't find formula for H
  class PSB(HessianUpdateStrategy):
  def __init__(
  self,
@@ -358,17 +364,85 @@ class PSB(HessianUpdateStrategy):
  def update_B(self, B, s, y, p, g, p_prev, g_prev, state, settings):
  return psb_B_(B=B, s=s, y=y, tol=settings['tol'])

- def pearson2_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+
+ # Algorithms from Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171
+ def pearson_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ Hy = H@y
+ yHy = y.dot(Hy)
+ if yHy.abs() <= tol: return H
+ num = (s - Hy).outer(Hy)
+ H += num.div_(yHy)
+ return H
+
+ class Pearson(HUpdateStrategy):
+ """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+ This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ return pearson_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ def mccormick_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
  sy = s.dot(y)
  if sy.abs() <= tol: return H
  num = (s - H@y).outer(s)
  H += num.div_(sy)
  return H

- class Pearson2(HUpdateStrategy):
- """finally found a reference in https://www.recotechnologies.com/~beigi/ps/asme-jdsmc-93-2.pdf"""
+ class McCormick(HUpdateStrategy):
+ """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+ This is "Algorithm 2", attributed to McCormick in this paper. However for some reason this method is also called Pearson's 2nd method."""
  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return pearson2_H_(H=H, s=s, y=y, tol=settings['tol'])
+ return mccormick_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ def projected_newton_raphson_H_(H: torch.Tensor, R:torch.Tensor, s: torch.Tensor, y: torch.Tensor, tol:float):
+ Hy = H @ y
+ yHy = y.dot(Hy)
+ if yHy.abs() < tol: return H, R
+ H -= Hy.outer(Hy) / yHy
+ R += (s - R@y).outer(Hy) / yHy
+ return H, R
+
+ class ProjectedNewtonRaphson(HessianUpdateStrategy):
+ """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+ Algorithm 7"""
+ def __init__(
+ self,
+ init_scale: float | Literal["auto"] = 'auto',
+ tol: float = 1e-10,
+ tol_reset: bool = True,
+ reset_interval: int | None | Literal['auto'] = 'auto',
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = True,
+ scale_second: bool = False,
+ concat_params: bool = True,
+ inner: Chainable | None = None,
+ ):
+ super().__init__(
+ init_scale=init_scale,
+ tol=tol,
+ tol_reset=tol_reset,
+ reset_interval=reset_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ scale_second=scale_second,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )
+
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ if 'R' not in state: state['R'] = torch.eye(H.size(-1), device=H.device, dtype=H.dtype)
+ H, R = projected_newton_raphson_H_(H=H, R=state['R'], s=s, y=y, tol=settings['tol'])
+ state["R"] = R
+ return H
+
+ def _reset_M_(self, M, s, y, inverse, init_scale, state):
+ assert inverse
+ M.copy_(state["R"])

  # Oren, S. S., & Spedicato, E. (1976). Optimal conditioning of self-scaling variable metric algorithms. Mathematical programming, 10(1), 70-90.
  def ssvm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, g:torch.Tensor, switch: tuple[float,float] | Literal[1,2,3,4], tol: float):
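Note (not part of the diff): in formulas, the updates added from Pearson's 1969 paper read

    \text{Pearson:}\quad H_{k+1} = H_k + \frac{(s - H_k y)(H_k y)^\top}{y^\top H_k y}, \qquad
    \text{McCormick:}\quad H_{k+1} = H_k + \frac{(s - H_k y)\, s^\top}{s^\top y},

and projected_newton_raphson_H_ maintains the pair (H, R):

    H \leftarrow H - \frac{(H y)(H y)^\top}{y^\top H y}, \qquad R \leftarrow R + \frac{(s - R y)(H y)^\top}{y^\top H y},

with H restored from R on reset, as in the _reset_M_ override above.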
@@ -473,4 +547,137 @@ class SSVM(HessianUpdateStrategy):
  )

  def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
- return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+ return ssvm_H_(H=H, s=s, y=y, g=g, switch=settings['switch'], tol=settings['tol'])
+
+ # HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+ def hoshino_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ Hy = H@y
+ ys = y.dot(s)
+ if ys.abs() <= tol: return H
+ yHy = y.dot(Hy)
+ denom = ys + yHy
+ if denom.abs() <= tol: return H
+
+ term1 = 1/denom
+ term2 = s.outer(s).mul_(1 + ((2 * yHy) / ys))
+ term3 = s.outer(y) @ H
+ term4 = Hy.outer(s)
+ term5 = Hy.outer(y) @ H
+
+ inner_term = term2 - term3 - term4 - term5
+ H += inner_term.mul_(term1)
+ return H
+
+ def gradient_correction(g: TensorList, s: TensorList, y: TensorList):
+ sy = s.dot(y)
+ if sy.abs() < torch.finfo(g[0].dtype).eps: return g
+ return g - (y * (s.dot(g) / sy))
+
+
+ class GradientCorrection(Transform):
+ """estimates gradient at minima along search direction assuming function is quadratic as proposed in HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394
+
+ This can useful as inner module for second order methods."""
+ def __init__(self):
+ super().__init__(None, uses_grad=False)
+
+ def apply(self, tensors, params, grads, loss, states, settings):
+ if 'p_prev' not in states[0]:
+ p_prev = unpack_states(states, tensors, 'p_prev', init=params)
+ g_prev = unpack_states(states, tensors, 'g_prev', init=tensors)
+ return tensors
+
+ p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+ g_hat = gradient_correction(TensorList(tensors), params-p_prev, tensors-g_prev)
+
+ p_prev.copy_(params)
+ g_prev.copy_(tensors)
+ return g_hat
+
+ class Horisho(HUpdateStrategy):
+ """HOSHINO, S. (1972). A Formulation of Variable Metric Methods. IMA Journal of Applied Mathematics, 10(3), 394–403. doi:10.1093/imamat/10.3.394"""
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ return hoshino_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+ # Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317
+ def fletcher_vmm_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+ sy = s.dot(y)
+ if sy.abs() < tol: return H
+ Hy = H @ y
+
+ term1 = (s.outer(y) @ H).div_(sy)
+ term2 = (Hy.outer(s)).div_(sy)
+ term3 = 1 + (y.dot(Hy) / sy)
+ term4 = s.outer(s).div_(sy)
+
+ H -= (term1 + term2 - term4.mul_(term3))
+ return H
+
+ class FletcherVMM(HUpdateStrategy):
+ """Fletcher, R. (1970). A new approach to variable metric algorithms. The Computer Journal, 13(3), 317–322. doi:10.1093/comjnl/13.3.317"""
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ return fletcher_vmm_H_(H=H, s=s, y=y, tol=settings['tol'])
+
+
+ # Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U.
+ def new_ssm1(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, f, f_prev, tol: float, type:int):
+ sy = s.dot(y)
+ if sy < tol: return H
+
+ term1 = (H @ y.outer(s) + s.outer(y) @ H) / sy
+
+ if type == 1:
+ pba = (2*sy + 2*(f-f_prev)) / sy
+
+ elif type == 2:
+ pba = (f_prev - f + 1/(2*sy)) / sy
+
+ else:
+ raise RuntimeError(type)
+
+ term3 = 1/pba + y.dot(H@y) / sy
+ term4 = s.outer(s) / sy
+
+ H.sub_(term1)
+ H.add_(term4.mul_(term3))
+ return H
+
+
+ class NewSSM(HessianUpdateStrategy):
+ """Self-scaling method, requires a line search.
+
+ Moghrabi, I. A., Hassan, B. A., & Askar, A. (2022). New self-scaling quasi-newton methods for unconstrained optimization. Int. J. Math. Comput. Sci., 17, 1061U."""
+ def __init__(
+ self,
+ type: Literal[1, 2] = 1,
+ init_scale: float | Literal["auto"] = "auto",
+ tol: float = 1e-10,
+ tol_reset: bool = True,
+ reset_interval: int | None = None,
+ beta: float | None = None,
+ update_freq: int = 1,
+ scale_first: bool = True,
+ scale_second: bool = False,
+ concat_params: bool = True,
+ inner: Chainable | None = None,
+ ):
+ super().__init__(
+ defaults=dict(type=type),
+ init_scale=init_scale,
+ tol=tol,
+ tol_reset=tol_reset,
+ reset_interval=reset_interval,
+ beta=beta,
+ update_freq=update_freq,
+ scale_first=scale_first,
+ scale_second=scale_second,
+ concat_params=concat_params,
+ inverse=True,
+ inner=inner,
+ )
+ def update_H(self, H, s, y, p, g, p_prev, g_prev, state, settings):
+ f = state['f']
+ f_prev = state['f_prev']
+ return new_ssm1(H=H, s=s, y=y, f=f, f_prev=f_prev, type=settings['type'], tol=settings['tol'])
+
+
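Note (not part of the diff): the GradientCorrection transform added above estimates, under a quadratic model, the gradient at the minimizer along the previous step via

    \hat g = g - \frac{s^\top g}{s^\top y}\, y,

which satisfies s^\top \hat g = 0, so the corrected gradient is orthogonal to the step s, matching the docstring and gradient_correction above. NewSSM additionally reads state['f'] and state['f_prev'] (now populated by update_tensor, hence the new loss argument) because its self-scaling factor depends on f - f_prev.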
torchzero/modules/second_order/newton.py
@@ -5,7 +5,7 @@ from typing import Literal

  import torch

- from ...core import Chainable, Module, apply
+ from ...core import Chainable, Module, apply_transform
  from ...utils import TensorList, vec_to_tensors
  from ...utils.derivatives import (
  hessian_list_to_mat,
@@ -18,9 +18,12 @@ from ...utils.derivatives import (


  def lu_solve(H: torch.Tensor, g: torch.Tensor):
- x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
- if info == 0: return x
- return None
+ try:
+ x, info = torch.linalg.solve_ex(H, g) # pylint:disable=not-callable
+ if info == 0: return x
+ return None
+ except RuntimeError:
+ return None

  def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
  x, info = torch.linalg.cholesky_ex(H) # pylint:disable=not-callable
@@ -32,10 +35,15 @@ def cholesky_solve(H: torch.Tensor, g: torch.Tensor):
  def least_squares_solve(H: torch.Tensor, g: torch.Tensor):
  return torch.linalg.lstsq(H, g)[0] # pylint:disable=not-callable

- def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None):
+ def eigh_solve(H: torch.Tensor, g: torch.Tensor, tfm: Callable | None, search_negative: bool):
  try:
  L, Q = torch.linalg.eigh(H) # pylint:disable=not-callable
  if tfm is not None: L = tfm(L)
+ if search_negative and L[0] < 0:
+ d = Q[0]
+ # use eigvec or -eigvec depending on if it points in same direction as gradient
+ return g.dot(d).sign() * d
+
  L.reciprocal_()
  return torch.linalg.multi_dot([Q * L.unsqueeze(-2), Q.mH, g]) # pylint:disable=not-callable
  except torch.linalg.LinAlgError:
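Note (not part of the diff): eigh_solve factors H = Q diag(L) Q^T with torch.linalg.eigh (optionally transforming the eigenvalues with eigval_tfm) and returns the Newton step as

    d = Q\,\mathrm{diag}(1/L)\,Q^\top g.

With search_negative=True and a negative smallest eigenvalue L[0], it instead returns an eigenvector signed to point along the gradient, i.e. a negative-curvature direction per the new Args entry below. One observation: torch.linalg.eigh returns eigenvectors as the columns of Q, so the eigenvector paired with L[0] is Q[:, 0], whereas Q[0] selects the first row.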
@@ -56,6 +64,8 @@ class Newton(Module):
  Args:
  reg (float, optional): tikhonov regularizer value. Defaults to 1e-6.
  eig_reg (bool, optional): whether to use largest negative eigenvalue as regularizer. Defaults to False.
+ search_negative (bool, Optional):
+ if True, whenever a negative eigenvalue is detected, the direction is taken along an eigenvector corresponding to a negative eigenvalue.
  hessian_method (str):
  how to calculate hessian. Defaults to "autograd".
  vectorize (bool, optional):
@@ -75,27 +85,29 @@ class Newton(Module):
  self,
  reg: float = 1e-6,
  eig_reg: bool = False,
+ search_negative: bool = False,
  hessian_method: Literal["autograd", "func", "autograd.functional"] = "autograd",
  vectorize: bool = True,
  inner: Chainable | None = None,
  H_tfm: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, bool]] | None = None,
  eigval_tfm: Callable[[torch.Tensor], torch.Tensor] | None = None,
  ):
- defaults = dict(reg=reg, eig_reg=eig_reg, abs=abs,hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm)
+ defaults = dict(reg=reg, eig_reg=eig_reg, hessian_method=hessian_method, vectorize=vectorize, H_tfm=H_tfm, eigval_tfm=eigval_tfm, search_negative=search_negative)
  super().__init__(defaults)

  if inner is not None:
  self.set_child('inner', inner)

  @torch.no_grad
- def step(self, vars):
- params = TensorList(vars.params)
- closure = vars.closure
+ def step(self, var):
+ params = TensorList(var.params)
+ closure = var.closure
  if closure is None: raise RuntimeError('NewtonCG requires closure')

  settings = self.settings[params[0]]
  reg = settings['reg']
  eig_reg = settings['eig_reg']
+ search_negative = settings['search_negative']
  hessian_method = settings['hessian_method']
  vectorize = settings['vectorize']
  H_tfm = settings['H_tfm']
@@ -104,16 +116,16 @@ class Newton(Module):
  # ------------------------ calculate grad and hessian ------------------------ #
  if hessian_method == 'autograd':
  with torch.enable_grad():
- loss = vars.loss = vars.loss_approx = closure(False)
+ loss = var.loss = var.loss_approx = closure(False)
  g_list, H_list = jacobian_and_hessian_wrt([loss], params, batched=vectorize)
  g_list = [t[0] for t in g_list] # remove leading dim from loss
- vars.grad = g_list
+ var.grad = g_list
  H = hessian_list_to_mat(H_list)

  elif hessian_method in ('func', 'autograd.functional'):
  strat = 'forward-mode' if vectorize else 'reverse-mode'
  with torch.enable_grad():
- g_list = vars.get_grad(retain_graph=True)
+ g_list = var.get_grad(retain_graph=True)
  H: torch.Tensor = hessian_mat(partial(closure, backward=False), params,
  method=hessian_method, vectorize=vectorize, outer_jacobian_strategy=strat) # pyright:ignore[reportAssignmentType]

@@ -121,10 +133,10 @@ class Newton(Module):
  raise ValueError(hessian_method)

  # -------------------------------- inner step -------------------------------- #
- update = vars.get_update()
+ update = var.get_update()
  if 'inner' in self.children:
- update = apply(self.children['inner'], update, params=params, grads=list(g_list), vars=vars)
- g = torch.cat([t.view(-1) for t in update])
+ update = apply_transform(self.children['inner'], update, params=params, grads=list(g_list), var=var)
+ g = torch.cat([t.ravel() for t in update])

  # ------------------------------- regulazition ------------------------------- #
  if eig_reg: H = eig_tikhonov_(H, reg)
@@ -134,14 +146,14 @@ class Newton(Module):
  update = None
  if H_tfm is not None:
  H, is_inv = H_tfm(H, g)
- if is_inv: update = H
+ if is_inv: update = H @ g

- if eigval_tfm is not None:
- update = eigh_solve(H, g, eigval_tfm)
+ if search_negative or (eigval_tfm is not None):
+ update = eigh_solve(H, g, eigval_tfm, search_negative=search_negative)

  if update is None: update = cholesky_solve(H, g)
  if update is None: update = lu_solve(H, g)
  if update is None: update = least_squares_solve(H, g)


- vars.update = vec_to_tensors(update, params)
- return vars
+ var.update = vec_to_tensors(update, params)
+ return var
torchzero/modules/second_order/newton_cg.py
@@ -6,14 +6,14 @@ import torch
  from ...utils import TensorList, as_tensorlist, generic_zeros_like, generic_vector_norm, generic_numel
  from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward

- from ...core import Chainable, apply, Module
+ from ...core import Chainable, apply_transform, Module
  from ...utils.linalg.solve import cg

  class NewtonCG(Module):
  def __init__(
  self,
  maxiter=None,
- tol=1e-3,
+ tol=1e-4,
  reg: float = 1e-8,
  hvp_method: Literal["forward", "central", "autograd"] = "forward",
  h=1e-3,
@@ -27,9 +27,9 @@ class NewtonCG(Module):
  self.set_child('inner', inner)

  @torch.no_grad
- def step(self, vars):
- params = TensorList(vars.params)
- closure = vars.closure
+ def step(self, var):
+ params = TensorList(var.params)
+ closure = var.closure
  if closure is None: raise RuntimeError('NewtonCG requires closure')

  settings = self.settings[params[0]]
@@ -42,7 +42,7 @@ class NewtonCG(Module):

  # ---------------------- Hessian vector product function --------------------- #
  if hvp_method == 'autograd':
- grad = vars.get_grad(create_graph=True)
+ grad = var.get_grad(create_graph=True)

  def H_mm(x):
  with torch.enable_grad():
@@ -51,7 +51,7 @@ class NewtonCG(Module):
  else:

  with torch.enable_grad():
- grad = vars.get_grad()
+ grad = var.get_grad()

  if hvp_method == 'forward':
  def H_mm(x):
@@ -66,19 +66,20 @@ class NewtonCG(Module):


  # -------------------------------- inner step -------------------------------- #
- b = vars.get_update()
+ b = var.get_update()
  if 'inner' in self.children:
- b = as_tensorlist(apply(self.children['inner'], b, params=params, grads=grad, vars=vars))
+ b = as_tensorlist(apply_transform(self.children['inner'], b, params=params, grads=grad, var=var))

  # ---------------------------------- run cg ---------------------------------- #
  x0 = None
- if warm_start: x0 = self.get_state('prev_x', params=params, cls=TensorList) # initialized to 0 which is default anyway
+ if warm_start: x0 = self.get_state(params, 'prev_x', cls=TensorList) # initialized to 0 which is default anyway
+
  x = cg(A_mm=H_mm, b=as_tensorlist(b), x0_=x0, tol=tol, maxiter=maxiter, reg=reg)
  if warm_start:
  assert x0 is not None
  x0.copy_(x)

- vars.update = x
- return vars
+ var.update = x
+ return var

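Note (not part of the diff): NewtonCG only ever touches the Hessian through Hessian-vector products, either exact ones from autograd or the finite-difference approximations selected by hvp_method ('forward' or 'central', step size h). Below is a minimal single-tensor sketch of those two schemes; the names grad_at / hvp_forward / hvp_central are illustrative only, the real hvp_fd_forward and hvp_fd_central helpers in torchzero.utils.derivatives operate on lists of parameters through the closure.

    import torch

    # Illustrative finite-difference Hessian-vector products for f: R^n -> R.

    def grad_at(f, x):
        # gradient of f at a detached copy of x
        x = x.detach().clone().requires_grad_(True)
        return torch.autograd.grad(f(x), x)[0]

    def hvp_forward(f, x, v, h=1e-3):
        # (grad f(x + h v) - grad f(x)) / h  ~  H(x) v
        return (grad_at(f, x + h * v) - grad_at(f, x)) / h

    def hvp_central(f, x, v, h=1e-3):
        # (grad f(x + h v) - grad f(x - h v)) / (2 h)  ~  H(x) v
        return (grad_at(f, x + h * v) - grad_at(f, x - h * v)) / (2 * h)

    f = lambda x: (x ** 4).sum()      # Hessian is diag(12 x^2)
    x = torch.tensor([1.0, 2.0])
    v = torch.tensor([1.0, 0.0])
    print(hvp_central(f, x, v))       # approx. tensor([12., 0.])

The forward scheme needs one extra gradient evaluation per product (O(h) error); the central scheme needs two (O(h^2) error).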