torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/cg.py
@@ -5,10 +5,42 @@ import torch
 
 from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
 from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states
+from .quasi_newton import _safe_clip, HessianUpdateStrategy
 
 
 class ConguateGradientBase(Transform, ABC):
-    """all CGs are the same except beta calculation"""
+    """Base class for conjugate gradient methods. The only difference between them is how beta is calculated.
+
+    This is an abstract class, to use it, subclass it and override `get_beta`.
+
+
+    Args:
+        defaults (dict | None, optional): dictionary of settings defaults. Defaults to None.
+        clip_beta (bool, optional): whether to clip beta to be no less than 0. Defaults to False.
+        reset_interval (int | None | Literal["auto"], optional):
+            interval between resetting the search direction.
+            "auto" means number of dimensions + 1, None means no reset. Defaults to None.
+        inner (Chainable | None, optional): previous direction is added to the output of this module. Defaults to None.
+
+    Example:
+
+    .. code-block:: python
+
+        class PolakRibiere(ConguateGradientBase):
+            def __init__(
+                self,
+                clip_beta=True,
+                reset_interval: int | None = None,
+                inner: Chainable | None = None
+            ):
+                super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
+
+            def get_beta(self, p, g, prev_g, prev_d):
+                denom = prev_g.dot(prev_g)
+                if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
+                return g.dot(g - prev_g) / denom
+
+    """
     def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['reset_interval'] = reset_interval
@@ -18,6 +50,15 @@ class ConguateGradientBase(Transform, ABC):
         if inner is not None:
             self.set_child('inner', inner)
 
+    def reset(self):
+        super().reset()
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_grad')
+        self.global_state.pop('stage', None)
+        self.global_state['step'] = self.global_state.get('step', 1) - 1
+
     def initialize(self, p: TensorList, g: TensorList):
         """runs on first step when prev_grads and prev_dir are not available"""
 
@@ -26,39 +67,55 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
         tensors = as_tensorlist(tensors)
         params = as_tensorlist(params)
 
-        step = self.global_state.get('step', 0)
-        prev_dir, prev_grads = unpack_states(states, tensors, 'prev_dir', 'prev_grad', cls=TensorList)
+        step = self.global_state.get('step', 0) + 1
+        self.global_state['step'] = step
 
         # initialize on first step
-        if step == 0:
+        if self.global_state.get('stage', 0) == 0:
+            g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+            d_prev.copy_(tensors)
+            g_prev.copy_(tensors)
             self.initialize(params, tensors)
-            prev_dir.copy_(tensors)
-            prev_grads.copy_(tensors)
-            self.global_state['step'] = step + 1
+            self.global_state['stage'] = 1
+
+        else:
+            # if `update_tensors` was called multiple times before `apply_tensors`,
+            # stage becomes 2
+            self.global_state['stage'] = 2
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        step = self.global_state['step']
+
+        if 'inner' in self.children:
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
+
+        assert self.global_state['stage'] != 0
+        if self.global_state['stage'] == 1:
+            self.global_state['stage'] = 2
             return tensors
 
+        params = as_tensorlist(params)
+        g_prev, d_prev = unpack_states(states, tensors, 'g_prev', 'd_prev', cls=TensorList)
+
         # get beta
-        beta = self.get_beta(params, tensors, prev_grads, prev_dir)
+        beta = self.get_beta(params, tensors, g_prev, d_prev)
         if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
-        prev_grads.copy_(tensors)
 
         # inner step
-        if 'inner' in self.children:
-            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))
-
         # calculate new direction with beta
-        dir = tensors.add_(prev_dir.mul_(beta))
-        prev_dir.copy_(dir)
+        dir = tensors.add_(d_prev.mul_(beta))
+        d_prev.copy_(dir)
 
         # resetting
-        self.global_state['step'] = step + 1
         reset_interval = settings[0]['reset_interval']
         if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
-        if reset_interval is not None and (step+1) % reset_interval == 0:
+        if reset_interval is not None and step % reset_interval == 0:
             self.reset()
 
         return dir
@@ -70,7 +127,11 @@ def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
     return g.dot(g - prev_g) / denom
 
 class PolakRibiere(ConguateGradientBase):
-    """Polak-Ribière-Polyak nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this."""
+    """Polak-Ribière-Polyak nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, clip_beta=True, reset_interval: int | None = None, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -83,7 +144,11 @@ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
     return gg / prev_gg
 
 class FletcherReeves(ConguateGradientBase):
-    """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Fletcher–Reeves nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -105,7 +170,11 @@ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 
 
 class HestenesStiefel(ConguateGradientBase):
-    """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Hestenes–Stiefel nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -120,7 +189,11 @@ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     return (g.dot(g) / denom).neg()
 
 class DaiYuan(ConguateGradientBase):
-    """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Dai–Yuan nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this. Although Dai–Yuan formula provides an automatic step size scaling so it is technically possible to omit line search and instead use a small step size.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -135,7 +208,11 @@ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
     return g.dot(g - prev_g) / denom
 
 class LiuStorey(ConguateGradientBase):
-    """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Liu-Storey nonlinear conjugate gradient method.
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -144,7 +221,11 @@ class LiuStorey(ConguateGradientBase):
 
 # ----------------------------- Conjugate Descent ---------------------------- #
 class ConjugateDescent(Transform):
-    """Conjugate Descent (CD). This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+    """Conjugate Descent (CD).
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, inner: Chainable | None = None):
         super().__init__(defaults={}, uses_grad=False)
 
@@ -153,7 +234,7 @@ class ConjugateDescent(Transform):
 
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
        g = as_tensorlist(tensors)
 
        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
@@ -188,7 +269,10 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
 
 class HagerZhang(ConguateGradientBase):
     """Hager-Zhang nonlinear conjugate gradient method,
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -212,7 +296,10 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 
 class HybridHS_DY(ConguateGradientBase):
     """HS-DY hybrid conjugate gradient method.
-    This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
+
+    .. note::
+        - This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+    """
     def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)
 
@@ -220,49 +307,63 @@ class HybridHS_DY(ConguateGradientBase):
         return hs_dy_beta(g, prev_d, prev_g)
 
 
-def projected_gradient_(H:torch.Tensor, y:torch.Tensor, tol: float):
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor):
     Hy = H @ y
-    denom = y.dot(Hy)
-    if denom.abs() < tol: return H
-    H -= (H @ y.outer(y) @ H) / denom
+    yHy = _safe_clip(y.dot(Hy))
+    H -= (Hy.outer(y) @ H) / yHy
     return H
 
-class ProjectedGradientMethod(TensorwiseTransform):
-    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+class ProjectedGradientMethod(HessianUpdateStrategy): # this doesn't maintain hessian
+    """Projected gradient method.
+
+    .. note::
+        This method uses N^2 memory.
+
+    .. note::
+        This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe(c2=0.1)` after this.
+
+    .. note::
+        This is not the same as projected gradient descent.
+
+    Reference:
+        Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
 
-    (This is not the same as projected gradient descent)
     """
 
     def __init__(
         self,
-        tol: float = 1e-10,
-        reset_interval: int | None = None,
+        init_scale: float | Literal["auto"] = 1,
+        tol: float = 1e-8,
+        ptol: float | None = 1e-10,
+        ptol_reset: bool = False,
+        gtol: float | None = 1e-10,
+        reset_interval: int | None | Literal['auto'] = 'auto',
+        beta: float | None = None,
         update_freq: int = 1,
         scale_first: bool = False,
+        scale_second: bool = False,
         concat_params: bool = True,
+        # inverse: bool = True,
         inner: Chainable | None = None,
     ):
-        defaults = dict(reset_interval=reset_interval, tol=tol)
-        super().__init__(defaults, uses_grad=False, scale_first=scale_first, concat_params=concat_params, update_freq=update_freq, inner=inner)
-
-    def update_tensor(self, tensor, param, grad, loss, state, settings):
-        step = state.get('step', 0)
-        state['step'] = step + 1
-        reset_interval = settings['reset_interval']
-        if reset_interval is None: reset_interval = tensor.numel() + 1 # as recommended
-
-        if ("H" not in state) or (step % reset_interval == 0):
-            state["H"] = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype)
-            state['g_prev'] = tensor.clone()
-            return
-
-        H = state['H']
-        g_prev = state['g_prev']
-        state['g_prev'] = tensor.clone()
-        y = (tensor - g_prev).ravel()
-
-        projected_gradient_(H, y, settings['tol'])
-
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        H = state['H']
-        return (H @ tensor.view(-1)).view_as(tensor)
+        super().__init__(
+            defaults=None,
+            init_scale=init_scale,
+            tol=tol,
+            ptol=ptol,
+            ptol_reset=ptol_reset,
+            gtol=gtol,
+            reset_interval=reset_interval,
+            beta=beta,
+            update_freq=update_freq,
+            scale_first=scale_first,
+            scale_second=scale_second,
+            concat_params=concat_params,
+            inverse=True,
+            inner=inner,
+        )
+
+
+
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return projected_gradient_(H=H, y=y)
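The note repeated across these docstrings, that each conjugate-gradient direction module should be followed by a line search such as :code:`StrongWolfe(c2=0.1)`, corresponds to chaining the two modules in one optimizer. The sketch below illustrates that pairing; the `tz.Modular` wrapper, the `tz.m` namespace and the closure convention are assumptions based on the library's README rather than anything shown in this diff, so treat it as a rough sketch, not confirmed API.

# Hypothetical usage sketch (not part of this diff): pairing a CG direction
# module with a Strong Wolfe line search, as the docstrings above recommend.
# `tz.Modular`, the `tz.m` namespace and the closure signature are assumed
# and may differ in detail from the released package.
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

opt = tz.Modular(
    model.parameters(),
    tz.m.PolakRibiere(),        # conjugate direction (beta via Polak-Ribiere)
    tz.m.StrongWolfe(c2=0.1),   # line search determines the step size along it
)

def closure(backward=True):
    loss = torch.nn.functional.mse_loss(model(X), y)
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(50):
    opt.step(closure)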
torchzero/modules/quasi_newton/diagonal_quasi_newton.py (new file)
@@ -0,0 +1,163 @@
+from collections.abc import Callable
+
+import torch
+
+from .quasi_newton import (
+    HessianUpdateStrategy,
+    _HessianUpdateStrategyDefaults,
+    _InverseHessianUpdateStrategyDefaults,
+    _safe_clip,
+)
+
+
+def _diag_Bv(self: HessianUpdateStrategy):
+    B, is_inverse = self.get_B()
+
+    if is_inverse:
+        H=B
+        def Hxv(v): return v/H
+        return Hxv
+
+    def Bv(v): return B*v
+    return Bv
+
+def _diag_Hv(self: HessianUpdateStrategy):
+    H, is_inverse = self.get_H()
+
+    if is_inverse:
+        B=H
+        def Bxv(v): return v/B
+        return Bxv
+
+    def Hv(v): return H*v
+    return Hv
+
+def diagonal_bfgs_H_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol: float):
+    sy = s.dot(y)
+    if sy < tol: return H
+
+    sy_sq = _safe_clip(sy**2)
+
+    num1 = (sy + (y * H * y)) * s*s
+    term1 = num1.div_(sy_sq)
+    num2 = (H * y * s).add_(s * y * H)
+    term2 = num2.div_(sy)
+    H += term1.sub_(term2)
+    return H
+
+class DiagonalBFGS(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal BFGS. This is simply BFGS with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_bfgs_H_(H=H, s=s, y=y, tol=setting['tol'])
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+def diagonal_sr1_(H:torch.Tensor, s: torch.Tensor, y:torch.Tensor, tol:float):
+    z = s - H*y
+    denom = z.dot(y)
+
+    z_norm = torch.linalg.norm(z) # pylint:disable=not-callable
+    y_norm = torch.linalg.norm(y) # pylint:disable=not-callable
+
+    # if y_norm*z_norm < tol: return H
+
+    # check as in Nocedal, Wright. “Numerical optimization” 2nd p.146
+    if denom.abs() <= tol * y_norm * z_norm: return H # pylint:disable=not-callable
+    H += (z*z).div_(_safe_clip(denom))
+    return H
+class DiagonalSR1(_InverseHessianUpdateStrategyDefaults):
+    """Diagonal SR1. This is simply SR1 with only the diagonal being updated and used. It doesn't satisfy the secant equation but may still be useful."""
+    def update_H(self, H, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=H, s=s, y=y, tol=setting['tol'])
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_sr1_(H=B, s=y, y=s, tol=setting['tol'])
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+
+
+# Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+def diagonal_qc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y) - (s*B).dot(s)
+    B += s**2 * (num/denom)
+    return B
+
+class DiagonalQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Zhu M., Nazareth J. L., Wolkowicz H. The quasi-Cauchy relation and diagonal updating //SIAM Journal on Optimization. – 1999. – Т. 9. – №. 4. – С. 1192-1204.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_qc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+# Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+def diagonal_wqc_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    E_sq = s**2 * B**2
+    denom = _safe_clip((s*E_sq).dot(s))
+    num = s.dot(y) - (s*B).dot(s)
+    B += E_sq * (num/denom)
+    return B
+
+class DiagonalWeightedQuasiCauchi(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-cauchi method.
+
+    Reference:
+        Leong, Wah June, Sharareh Enshaei, and Sie Long Kek. "Diagonal quasi-Newton methods via least change updating principle with weighted Frobenius norm." Numerical Algorithms 86 (2021): 1225-1241.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+
+# Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+def dnrtr_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y) + s.dot(s) - (s*B).dot(s)
+    B += s**2 * (num/denom) - 1
+    return B
+
+class DNRTR(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Andrei, Neculai. "A diagonal quasi-Newton updating method for unconstrained optimization." Numerical Algorithms 81.2 (2019): 575-590.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return diagonal_wqc_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
+
+# Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+def new_dqn_B_(B:torch.Tensor, s: torch.Tensor, y:torch.Tensor):
+    denom = _safe_clip((s**4).sum())
+    num = s.dot(y)
+    B += s**2 * (num/denom)
+    return B
+
+class NewDQN(_HessianUpdateStrategyDefaults):
+    """Diagonal quasi-newton method.
+
+    Reference:
+        Nosrati, Mahsa, and Keyvan Amini. "A new diagonal quasi-Newton algorithm for unconstrained optimization problems." Applications of Mathematics 69.4 (2024): 501-512.
+    """
+    def update_B(self, B, s, y, p, g, p_prev, g_prev, state, setting):
+        return new_dqn_B_(B=B, s=s, y=y)
+
+    def _init_M(self, size:int, device, dtype, is_inverse:bool): return torch.ones(size, device=device, dtype=dtype)
+    def make_Bv(self): return _diag_Bv(self)
+    def make_Hv(self): return _diag_Hv(self)
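For reference, the diagonal quasi-Cauchy update implemented by `diagonal_qc_B_` above enforces the weak secant (quasi-Cauchy) relation s^T B s = s^T y after each step, which is easy to verify numerically. A minimal standalone check, with the update copied from the new file and the `_safe_clip` guard dropped for brevity:

import torch

def diagonal_qc_B_(B: torch.Tensor, s: torch.Tensor, y: torch.Tensor):
    # same diagonal update as in diagonal_quasi_newton.py, minus _safe_clip
    denom = (s**4).sum()
    num = s.dot(y) - (s * B).dot(s)
    B += s**2 * (num / denom)
    return B

torch.manual_seed(0)
B = torch.ones(5)      # diagonal Hessian approximation
s = torch.randn(5)     # parameter difference s = p - p_prev
y = torch.randn(5)     # gradient difference y = g - g_prev

B = diagonal_qc_B_(B, s, y)
print(torch.allclose((s * B).dot(s), s.dot(y)))  # True: s^T B s == s^T y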