torchzero 0.3.10__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (182)
  1. tests/test_identical.py +2 -3
  2. tests/test_opts.py +140 -100
  3. tests/test_tensorlist.py +8 -7
  4. tests/test_vars.py +1 -0
  5. torchzero/__init__.py +1 -1
  6. torchzero/core/__init__.py +2 -2
  7. torchzero/core/module.py +335 -50
  8. torchzero/core/reformulation.py +65 -0
  9. torchzero/core/transform.py +197 -70
  10. torchzero/modules/__init__.py +13 -4
  11. torchzero/modules/adaptive/__init__.py +30 -0
  12. torchzero/modules/adaptive/adagrad.py +356 -0
  13. torchzero/modules/adaptive/adahessian.py +224 -0
  14. torchzero/modules/{optimizers → adaptive}/adam.py +6 -8
  15. torchzero/modules/adaptive/adan.py +96 -0
  16. torchzero/modules/adaptive/adaptive_heavyball.py +54 -0
  17. torchzero/modules/adaptive/aegd.py +54 -0
  18. torchzero/modules/adaptive/esgd.py +171 -0
  19. torchzero/modules/{optimizers → adaptive}/lion.py +1 -1
  20. torchzero/modules/{experimental/spectral.py → adaptive/lmadagrad.py} +94 -71
  21. torchzero/modules/adaptive/mars.py +79 -0
  22. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  23. torchzero/modules/adaptive/msam.py +188 -0
  24. torchzero/modules/{optimizers → adaptive}/muon.py +29 -5
  25. torchzero/modules/adaptive/natural_gradient.py +175 -0
  26. torchzero/modules/{optimizers → adaptive}/orthograd.py +1 -1
  27. torchzero/modules/{optimizers → adaptive}/rmsprop.py +7 -4
  28. torchzero/modules/{optimizers → adaptive}/rprop.py +42 -10
  29. torchzero/modules/adaptive/sam.py +163 -0
  30. torchzero/modules/{optimizers → adaptive}/shampoo.py +47 -9
  31. torchzero/modules/{optimizers → adaptive}/soap.py +52 -65
  32. torchzero/modules/adaptive/sophia_h.py +185 -0
  33. torchzero/modules/clipping/clipping.py +115 -25
  34. torchzero/modules/clipping/ema_clipping.py +31 -17
  35. torchzero/modules/clipping/growth_clipping.py +8 -7
  36. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  37. torchzero/modules/conjugate_gradient/cg.py +355 -0
  38. torchzero/modules/experimental/__init__.py +13 -19
  39. torchzero/modules/{projections → experimental}/dct.py +11 -11
  40. torchzero/modules/{projections → experimental}/fft.py +10 -10
  41. torchzero/modules/experimental/gradmin.py +4 -3
  42. torchzero/modules/experimental/l_infinity.py +111 -0
  43. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +5 -42
  44. torchzero/modules/experimental/newton_solver.py +79 -17
  45. torchzero/modules/experimental/newtonnewton.py +32 -15
  46. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  47. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  48. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +13 -55
  49. torchzero/modules/functional.py +52 -6
  50. torchzero/modules/grad_approximation/fdm.py +30 -4
  51. torchzero/modules/grad_approximation/forward_gradient.py +16 -4
  52. torchzero/modules/grad_approximation/grad_approximator.py +51 -10
  53. torchzero/modules/grad_approximation/rfdm.py +321 -52
  54. torchzero/modules/higher_order/__init__.py +1 -1
  55. torchzero/modules/higher_order/higher_order_newton.py +164 -93
  56. torchzero/modules/least_squares/__init__.py +1 -0
  57. torchzero/modules/least_squares/gn.py +161 -0
  58. torchzero/modules/line_search/__init__.py +4 -4
  59. torchzero/modules/line_search/_polyinterp.py +289 -0
  60. torchzero/modules/line_search/adaptive.py +124 -0
  61. torchzero/modules/line_search/backtracking.py +95 -57
  62. torchzero/modules/line_search/line_search.py +171 -22
  63. torchzero/modules/line_search/scipy.py +3 -3
  64. torchzero/modules/line_search/strong_wolfe.py +327 -199
  65. torchzero/modules/misc/__init__.py +35 -0
  66. torchzero/modules/misc/debug.py +48 -0
  67. torchzero/modules/misc/escape.py +62 -0
  68. torchzero/modules/misc/gradient_accumulation.py +136 -0
  69. torchzero/modules/misc/homotopy.py +59 -0
  70. torchzero/modules/misc/misc.py +383 -0
  71. torchzero/modules/misc/multistep.py +194 -0
  72. torchzero/modules/misc/regularization.py +167 -0
  73. torchzero/modules/misc/split.py +123 -0
  74. torchzero/modules/{ops → misc}/switch.py +45 -4
  75. torchzero/modules/momentum/__init__.py +1 -5
  76. torchzero/modules/momentum/averaging.py +9 -9
  77. torchzero/modules/momentum/cautious.py +51 -19
  78. torchzero/modules/momentum/momentum.py +37 -2
  79. torchzero/modules/ops/__init__.py +11 -31
  80. torchzero/modules/ops/accumulate.py +6 -10
  81. torchzero/modules/ops/binary.py +81 -34
  82. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +16 -39
  83. torchzero/modules/ops/multi.py +82 -21
  84. torchzero/modules/ops/reduce.py +16 -8
  85. torchzero/modules/ops/unary.py +29 -13
  86. torchzero/modules/ops/utility.py +30 -18
  87. torchzero/modules/projections/__init__.py +2 -4
  88. torchzero/modules/projections/cast.py +51 -0
  89. torchzero/modules/projections/galore.py +3 -1
  90. torchzero/modules/projections/projection.py +190 -96
  91. torchzero/modules/quasi_newton/__init__.py +9 -14
  92. torchzero/modules/quasi_newton/damping.py +105 -0
  93. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -0
  94. torchzero/modules/quasi_newton/lbfgs.py +286 -173
  95. torchzero/modules/quasi_newton/lsr1.py +185 -106
  96. torchzero/modules/quasi_newton/quasi_newton.py +816 -268
  97. torchzero/modules/restarts/__init__.py +7 -0
  98. torchzero/modules/restarts/restars.py +252 -0
  99. torchzero/modules/second_order/__init__.py +3 -2
  100. torchzero/modules/second_order/multipoint.py +238 -0
  101. torchzero/modules/second_order/newton.py +292 -68
  102. torchzero/modules/second_order/newton_cg.py +365 -15
  103. torchzero/modules/second_order/nystrom.py +104 -1
  104. torchzero/modules/smoothing/__init__.py +1 -1
  105. torchzero/modules/smoothing/laplacian.py +14 -4
  106. torchzero/modules/smoothing/sampling.py +300 -0
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +387 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/termination/__init__.py +14 -0
  111. torchzero/modules/termination/termination.py +207 -0
  112. torchzero/modules/trust_region/__init__.py +5 -0
  113. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  114. torchzero/modules/trust_region/dogleg.py +92 -0
  115. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  116. torchzero/modules/trust_region/trust_cg.py +97 -0
  117. torchzero/modules/trust_region/trust_region.py +350 -0
  118. torchzero/modules/variance_reduction/__init__.py +1 -0
  119. torchzero/modules/variance_reduction/svrg.py +208 -0
  120. torchzero/modules/weight_decay/__init__.py +1 -1
  121. torchzero/modules/weight_decay/weight_decay.py +94 -11
  122. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  123. torchzero/modules/zeroth_order/__init__.py +1 -0
  124. torchzero/modules/zeroth_order/cd.py +359 -0
  125. torchzero/optim/root.py +65 -0
  126. torchzero/optim/utility/split.py +8 -8
  127. torchzero/optim/wrappers/directsearch.py +39 -3
  128. torchzero/optim/wrappers/fcmaes.py +24 -15
  129. torchzero/optim/wrappers/mads.py +5 -6
  130. torchzero/optim/wrappers/nevergrad.py +16 -1
  131. torchzero/optim/wrappers/nlopt.py +0 -2
  132. torchzero/optim/wrappers/optuna.py +3 -3
  133. torchzero/optim/wrappers/scipy.py +86 -25
  134. torchzero/utils/__init__.py +40 -4
  135. torchzero/utils/compile.py +1 -1
  136. torchzero/utils/derivatives.py +126 -114
  137. torchzero/utils/linalg/__init__.py +9 -2
  138. torchzero/utils/linalg/linear_operator.py +329 -0
  139. torchzero/utils/linalg/matrix_funcs.py +2 -2
  140. torchzero/utils/linalg/orthogonalize.py +2 -1
  141. torchzero/utils/linalg/qr.py +2 -2
  142. torchzero/utils/linalg/solve.py +369 -58
  143. torchzero/utils/metrics.py +83 -0
  144. torchzero/utils/numberlist.py +2 -0
  145. torchzero/utils/python_tools.py +16 -0
  146. torchzero/utils/tensorlist.py +134 -51
  147. torchzero/utils/torch_tools.py +9 -4
  148. torchzero-0.3.13.dist-info/METADATA +14 -0
  149. torchzero-0.3.13.dist-info/RECORD +166 -0
  150. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  151. docs/source/conf.py +0 -57
  152. torchzero/modules/experimental/absoap.py +0 -250
  153. torchzero/modules/experimental/adadam.py +0 -112
  154. torchzero/modules/experimental/adamY.py +0 -125
  155. torchzero/modules/experimental/adasoap.py +0 -172
  156. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  157. torchzero/modules/experimental/eigendescent.py +0 -117
  158. torchzero/modules/experimental/etf.py +0 -172
  159. torchzero/modules/experimental/soapy.py +0 -163
  160. torchzero/modules/experimental/structured_newton.py +0 -111
  161. torchzero/modules/experimental/subspace_preconditioners.py +0 -138
  162. torchzero/modules/experimental/tada.py +0 -38
  163. torchzero/modules/line_search/trust_region.py +0 -73
  164. torchzero/modules/lr/__init__.py +0 -2
  165. torchzero/modules/lr/adaptive.py +0 -93
  166. torchzero/modules/lr/lr.py +0 -63
  167. torchzero/modules/momentum/matrix_momentum.py +0 -166
  168. torchzero/modules/ops/debug.py +0 -25
  169. torchzero/modules/ops/misc.py +0 -418
  170. torchzero/modules/ops/split.py +0 -75
  171. torchzero/modules/optimizers/__init__.py +0 -18
  172. torchzero/modules/optimizers/adagrad.py +0 -155
  173. torchzero/modules/optimizers/sophia_h.py +0 -129
  174. torchzero/modules/quasi_newton/cg.py +0 -268
  175. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  176. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +0 -266
  177. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  178. torchzero/modules/smoothing/gaussian.py +0 -164
  179. torchzero-0.3.10.dist-info/METADATA +0 -379
  180. torchzero-0.3.10.dist-info/RECORD +0 -139
  181. torchzero-0.3.10.dist-info/licenses/LICENSE +0 -21
  182. {torchzero-0.3.10.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/lsr1.py
@@ -1,174 +1,253 @@
 from collections import deque
+from collections.abc import Sequence
 from operator import itemgetter
 
 import torch
 
 from ...core import Chainable, Module, Transform, Var, apply_transform
-from ...utils import NumberList, TensorList, as_tensorlist
-
-from .lbfgs import _lerp_params_update_
-
-def lsr1_(
-    tensors_: TensorList,
-    s_history: deque[TensorList],
-    y_history: deque[TensorList],
-    step: int,
-    scale_second: bool,
-):
-    if step == 0 or not s_history:
-        # initial step size guess from pytorch
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
+from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
+from ...utils.linalg.linear_operator import LinearOperator
+from ..functional import initial_step_size
+from .damping import DampingStrategyType, apply_damping
 
+
+def lsr1_Hx(x, s_history: Sequence, y_history: Sequence,):
     m = len(s_history)
+    if m == 0: return x.clone()
+    eps = generic_finfo_tiny(x) * 2
 
-    w_list: list[TensorList] = []
-    ww_list: list = [None for _ in range(m)]
+    w_list = []
     wy_list: list = [None for _ in range(m)]
 
-    # 1st loop - all w_k = s_k - H_k_prev y_k
+    # # 1st loop - all w_k = s_k - H_k_prev y_k
     for k in range(m):
         s_k = s_history[k]
         y_k = y_history[k]
 
-        H_k = y_k.clone()
+        Hx = y_k.clone()
         for j in range(k):
             w_j = w_list[j]
             y_j = y_history[j]
 
             wy = wy_list[j]
             if wy is None: wy = wy_list[j] = w_j.dot(y_j)
+            if wy.abs() < eps: continue
 
-            ww = ww_list[j]
-            if ww is None: ww = ww_list[j] = w_j.dot(w_j)
-
-            if wy == 0: continue
-
-            H_k.add_(w_j, alpha=w_j.dot(y_k) / wy) # pyright:ignore[reportArgumentType]
+            alpha = w_j.dot(y_k) / wy
+            Hx.add_(w_j, alpha=alpha)
 
-        w_k = s_k - H_k
+        w_k = s_k - Hx
         w_list.append(w_k)
 
-    Hx = tensors_.clone()
+    Hx = x.clone()
+
+    # second loop
     for k in range(m):
         w_k = w_list[k]
         y_k = y_history[k]
         wy = wy_list[k]
-        ww = ww_list[k]
 
         if wy is None: wy = w_k.dot(y_k) # this happens when m = 1 so inner loop doesn't run
-        if ww is None: ww = w_k.dot(w_k)
+        if wy.abs() < eps: continue
+
+        alpha = w_k.dot(x) / wy
+        Hx.add_(w_k, alpha=alpha)
 
-        if wy == 0: continue
+    return Hx
 
-        Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
+def lsr1_Bx(x, s_history: Sequence, y_history: Sequence,):
+    return lsr1_Hx(x, s_history=y_history, y_history=s_history)
 
-    if scale_second and step == 1:
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        Hx.mul_(scale_factor)
+class LSR1LinearOperator(LinearOperator):
+    def __init__(self, s_history: Sequence[torch.Tensor], y_history: Sequence[torch.Tensor]):
+        super().__init__()
+        self.s_history = s_history
+        self.y_history = y_history
+
+    def solve(self, b):
+        return lsr1_Hx(x=b, s_history=self.s_history, y_history=self.y_history)
+
+    def matvec(self, x):
+        return lsr1_Bx(x=x, s_history=self.s_history, y_history=self.y_history)
+
+    def size(self):
+        if len(self.s_history) == 0: raise RuntimeError()
+        n = len(self.s_history[0])
+        return (n, n)
 
-    return Hx
 
+class LSR1(Transform):
+    """Limited-memory SR1 algorithm. A line search or trust region is recommended.
 
-class LSR1(Module):
-    """Limited Memory SR1 (L-SR1)
     Args:
-        history_size (int, optional): Number of past parameter differences (s)
-            and gradient differences (y) to store. Defaults to 10.
-        skip_R_val (float, optional): Tolerance R for the SR1 update skip condition
-            |w_k^T y_k| >= R * ||w_k|| * ||y_k||. Defaults to 1e-8.
-            Updates where this condition is not met are skipped during history accumulation
-            and matrix-vector products.
-        params_beta (float | None, optional): If not None, EMA of parameters is used for
-            preconditioner update (s_k vector). Defaults to None.
-        grads_beta (float | None, optional): If not None, EMA of gradients is used for
-            preconditioner update (y_k vector). Defaults to None.
-        update_freq (int, optional): How often to update L-SR1 history. Defaults to 1.
-        conv_tol (float | None, optional): Tolerance for y_k norm. If max abs value of y_k
-            is below this, the preconditioning step might be skipped, assuming convergence.
-            Defaults to 1e-10.
-        inner (Chainable | None, optional): Optional inner modules applied after updating
-            L-SR1 history and before preconditioning. Defaults to None.
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
+        ptol (float | None, optional):
+            skips updating the history if the maximum absolute value of the
+            parameter difference is less than this value. Defaults to None.
+        ptol_restart (bool, optional):
+            If true, whenever the parameter difference is less than ``ptol``,
+            L-SR1 state will be reset. Defaults to False.
+        gtol (float | None, optional):
+            skips updating the history if the maximum absolute value of the
+            gradient difference is less than this value. Defaults to None.
+        gtol_restart (bool, optional):
+            If true, whenever the gradient difference is less than ``gtol``,
+            L-SR1 state will be reset. Defaults to False.
+        scale_first (bool, optional):
+            makes the first step, when the hessian approximation is not yet available,
+            small to reduce the number of line search iterations. Defaults to False.
+        update_freq (int, optional):
+            how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
+        damping (DampingStrategyType, optional):
+            damping to use, can be "powell" or "double". Defaults to None.
+        compact (bool, optional):
+            if True, uses a compact-representation version of L-SR1. It is much faster computationally, but less stable.
+        inner (Chainable | None, optional):
+            optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.
+
+    ## Examples:
+
+    L-SR1 with line search
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.SR1(),
+        tz.m.StrongWolfe(c2=0.1, fallback=True)
+    )
+    ```
+
+    L-SR1 with trust region
+    ```python
+    opt = tz.Modular(
+        model.parameters(),
+        tz.m.TrustCG(tz.m.LSR1())
+    )
+    ```
     """
     def __init__(
         self,
-        history_size: int = 10,
-        tol: float = 1e-8,
-        params_beta: float | None = None,
-        grads_beta: float | None = None,
-        update_freq: int = 1,
-        scale_second: bool = True,
+        history_size=10,
+        ptol: float | None = None,
+        ptol_restart: bool = False,
+        gtol: float | None = None,
+        gtol_restart: bool = False,
+        scale_first: bool = False,
+        update_freq = 1,
+        damping: DampingStrategyType = None,
         inner: Chainable | None = None,
     ):
         defaults = dict(
-            history_size=history_size, tol=tol,
-            params_beta=params_beta, grads_beta=grads_beta,
-            update_freq=update_freq, scale_second=scale_second
+            history_size=history_size,
+            scale_first=scale_first,
+            ptol=ptol,
+            gtol=gtol,
+            ptol_restart=ptol_restart,
+            gtol_restart=gtol_restart,
+            damping = damping,
         )
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
 
         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
 
-        if inner is not None:
-            self.set_child('inner', inner)
-
-    def reset(self):
+    def _reset_self(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()
 
+    def reset(self):
+        self._reset_self()
+        for c in self.children.values(): c.reset()
+
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('p_prev', 'g_prev')
+        self.global_state.pop('step', None)
 
     @torch.no_grad
-    def step(self, var: Var):
-        params = as_tensorlist(var.params)
-        update = as_tensorlist(var.get_update())
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        p = as_tensorlist(params)
+        g = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1
 
-        s_history: deque[TensorList] = self.global_state['s_history']
-        y_history: deque[TensorList] = self.global_state['y_history']
-
-        settings = self.settings[params[0]]
-        tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)
+        # history of s and y
+        s_history: deque = self.global_state['s_history']
+        y_history: deque = self.global_state['y_history']
+
+        ptol = self.defaults['ptol']
+        gtol = self.defaults['gtol']
+        ptol_restart = self.defaults['ptol_restart']
+        gtol_restart = self.defaults['gtol_restart']
+        damping = self.defaults['damping']
+
+        p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
+
+        # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
+        if step == 0:
+            s = None; y = None; sy = None
+        else:
+            s = p - p_prev
+            y = g - g_prev
+
+            if damping is not None:
+                s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
+
+            sy = s.dot(y)
+            # damping to be added here
+
+        below_tol = False
+        # tolerance on parameter difference to avoid exploding after converging
+        if ptol is not None:
+            if s is not None and s.abs().global_max() <= ptol:
+                if ptol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if gtol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                if gtol_restart: self._reset_self()
+                sy = None
+                below_tol = True
+
+        # store previous params and grads
+        if not below_tol:
+            p_prev.copy_(p)
+            g_prev.copy_(g)
+
+        # update effective preconditioning state
+        if sy is not None:
+            assert s is not None and y is not None and sy is not None
+
+            s_history.append(s)
+            y_history.append(y)
+
+    def get_H(self, var=...):
+        s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+        y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+        return LSR1LinearOperator(s_history, y_history)
 
-        params_beta, grads_beta_ = self.get_settings(params, 'params_beta', 'grads_beta') # type: ignore
-        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)
-
-        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)
-
-        y_k = None
-        if step != 0:
-            if step % update_freq == 0:
-                s_k = l_params - prev_l_params
-                y_k = l_update - prev_l_grad
-
-                s_history.append(s_k)
-                y_history.append(y_k)
-
-                prev_l_params.copy_(l_params)
-                prev_l_grad.copy_(l_update)
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        scale_first = self.defaults['scale_first']
 
-        if 'inner' in self.children:
-            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))
+        tensors = as_tensorlist(tensors)
 
-        # tolerance on gradient difference to avoid exploding after converging
-        if tol is not None:
-            if y_k is not None and y_k.abs().global_max() <= tol:
-                var.update = update
-                return var
+        s_history = self.global_state['s_history']
+        y_history = self.global_state['y_history']
 
-        dir = lsr1_(
-            tensors_=update,
+        # precondition
+        dir = lsr1_Hx(
+            x=tensors,
             s_history=s_history,
             y_history=y_history,
-            step=step,
-            scale_second=scale_second,
         )
 
-        var.update = dir
+        # scale 1st step
+        if scale_first and self.global_state.get('step', 1) == 1:
+            dir *= initial_step_size(dir, eps=1e-7)
 
-        return var
+        return dir
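
For context on what the new `lsr1_Hx` recursion computes: it applies the limited-memory SR1 inverse-Hessian approximation H_m = I + sum_k w_k w_k^T / (w_k^T y_k), with w_k = s_k - H_{k-1} y_k, skipping terms whose denominator is near zero. Below is a minimal standalone sketch of that same matrix-vector product on plain `torch.Tensor` vectors; the function name `lsr1_hvp`, the `eps` value, and the quadratic check are illustrative and not part of torchzero.

```python
import torch

def lsr1_hvp(x, s_history, y_history, eps=1e-12):
    """Apply the L-SR1 inverse-Hessian approximation H_m to a vector x.

    H_0 = I,  H_k = H_{k-1} + w_k w_k^T / (w_k^T y_k),  w_k = s_k - H_{k-1} y_k.
    Updates with a near-zero denominator are skipped (the usual SR1 safeguard).
    """
    w_list, wy_list = [], []
    for s_k, y_k in zip(s_history, y_history):
        Hy = y_k.clone()                      # H_{k-1} y_k from the terms built so far
        for w_j, wy_j in zip(w_list, wy_list):
            if wy_j.abs() > eps:
                Hy += (torch.dot(w_j, y_k) / wy_j) * w_j
        w_k = s_k - Hy
        w_list.append(w_k)
        wy_list.append(torch.dot(w_k, y_k))

    Hx = x.clone()                            # H_m x
    for w_k, wy_k in zip(w_list, wy_list):
        if wy_k.abs() > eps:
            Hx += (torch.dot(w_k, x) / wy_k) * w_k
    return Hx

# On a quadratic with Hessian B, feeding pairs (s, y = B s) makes H satisfy the
# secant equations, so H y_k should recover s_k (up to skipped updates).
B = torch.tensor([[3.0, 0.5], [0.5, 2.0]])
S = [torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])]
Y = [B @ s for s in S]
print(lsr1_hvp(Y[-1], S, Y))   # ~ S[-1]
```

In the diff above, `LSR1LinearOperator.solve` applies exactly this H to a vector, while `matvec` obtains the direct (non-inverse) approximation by swapping the roles of the s and y histories via `lsr1_Bx`.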