torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
torchzero/modules/quasi_newton/lsr1.py
@@ -1,218 +1,253 @@
  from collections import deque
+ from collections.abc import Sequence
  from operator import itemgetter
 
  import torch
 
  from ...core import Chainable, Module, Transform, Var, apply_transform
- from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
- from ..functional import safe_scaling_
- from .lbfgs import _lerp_params_update_
-
-
- def lsr1_(
-     tensors_: TensorList,
-     s_history: deque[TensorList],
-     y_history: deque[TensorList],
-     step: int,
-     scale_second: bool,
- ):
-     if len(s_history) == 0:
-         # initial step size guess from pytorch
-         return safe_scaling_(TensorList(tensors_))
+ from ...utils import NumberList, TensorList, as_tensorlist, generic_finfo_tiny, unpack_states, vec_to_tensors_
+ from ...utils.linalg.linear_operator import LinearOperator
+ from ..functional import initial_step_size
+ from .damping import DampingStrategyType, apply_damping
 
+
+ def lsr1_Hx(x, s_history: Sequence, y_history: Sequence):
      m = len(s_history)
+     if m == 0: return x.clone()
+     eps = generic_finfo_tiny(x) * 2
 
-     w_list: list[TensorList] = []
-     ww_list: list = [None for _ in range(m)]
+     w_list = []
      wy_list: list = [None for _ in range(m)]
 
-     # 1st loop - all w_k = s_k - H_k_prev y_k
+     # # 1st loop - all w_k = s_k - H_k_prev y_k
      for k in range(m):
          s_k = s_history[k]
          y_k = y_history[k]
 
-         H_k = y_k.clone()
+         Hx = y_k.clone()
          for j in range(k):
              w_j = w_list[j]
              y_j = y_history[j]
 
             wy = wy_list[j]
             if wy is None: wy = wy_list[j] = w_j.dot(y_j)
+             if wy.abs() < eps: continue
 
-             ww = ww_list[j]
-             if ww is None: ww = ww_list[j] = w_j.dot(w_j)
-
-             if wy == 0: continue
+             alpha = w_j.dot(y_k) / wy
+             Hx.add_(w_j, alpha=alpha)
 
-             H_k.add_(w_j, alpha=w_j.dot(y_k) / wy) # pyright:ignore[reportArgumentType]
-
-         w_k = s_k - H_k
+         w_k = s_k - Hx
          w_list.append(w_k)
 
-     Hx = tensors_.clone()
+     Hx = x.clone()
+
+     # second loop
      for k in range(m):
          w_k = w_list[k]
          y_k = y_history[k]
          wy = wy_list[k]
-         ww = ww_list[k]
 
          if wy is None: wy = w_k.dot(y_k) # this happens when m = 1 so inner loop doesn't run
-         if ww is None: ww = w_k.dot(w_k)
+         if wy.abs() < eps: continue
 
-         if wy == 0: continue
+         alpha = w_k.dot(x) / wy
+         Hx.add_(w_k, alpha=alpha)
 
-         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]
+     return Hx
 
-     if scale_second and step == 2:
-         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-         Hx.mul_(scale_factor)
+ def lsr1_Bx(x, s_history: Sequence, y_history: Sequence):
+     return lsr1_Hx(x, s_history=y_history, y_history=s_history)
 
-     return Hx
+ class LSR1LinearOperator(LinearOperator):
+     def __init__(self, s_history: Sequence[torch.Tensor], y_history: Sequence[torch.Tensor]):
+         super().__init__()
+         self.s_history = s_history
+         self.y_history = y_history
 
+     def solve(self, b):
+         return lsr1_Hx(x=b, s_history=self.s_history, y_history=self.y_history)
 
- class LSR1(Transform):
-     """Limited Memory SR1 algorithm. A line search is recommended.
+     def matvec(self, x):
+         return lsr1_Bx(x=x, s_history=self.s_history, y_history=self.y_history)
 
-     .. note::
-         L-SR1 provides a better estimate of true hessian, however it is more unstable compared to L-BFGS.
+     def size(self):
+         if len(self.s_history) == 0: raise RuntimeError()
+         n = len(self.s_history[0])
+         return (n, n)
 
-     .. note::
-         L-SR1 update rule uses a nested loop, computationally with history size `n` it is similar to L-BFGS with history size `(n^2)/2`. On small problems (ndim <= 2000) BFGS and SR1 may be faster than limited-memory versions.
 
-     .. note::
-         directions L-SR1 generates are not guaranteed to be descent directions. This can be alleviated in multiple ways,
-         for example using :code:`tz.m.StrongWolfe(plus_minus=True)` line search, or modifying the direction with :code:`tz.m.Cautious` or :code:`tz.m.ScaleByGradCosineSimilarity`.
+ class LSR1(Transform):
+     """Limited-memory SR1 algorithm. A line search or trust region is recommended.
 
      Args:
          history_size (int, optional):
              number of past parameter differences and gradient differences to store. Defaults to 10.
-         tol (float | None, optional):
-             tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
-         tol_reset (bool, optional):
-             If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+         ptol (float | None, optional):
+             skips updating the history if maximum absolute value of
+             parameter difference is less than this value. Defaults to None.
+         ptol_restart (bool, optional):
+             If true, whenever parameter difference is less than ``ptol``,
+             L-SR1 state will be reset. Defaults to None.
          gtol (float | None, optional):
-             tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
-         params_beta (float | None, optional):
-             if not None, EMA of parameters is used for
-             preconditioner update (s_k vector). Defaults to None.
-         grads_beta (float | None, optional):
-             if not None, EMA of gradients is used for
-             preconditioner update (y_k vector). Defaults to None.
-         update_freq (int, optional): How often to update L-SR1 history. Defaults to 1.
-         scale_second (bool, optional): downscales second update which tends to be large. Defaults to False.
+             skips updating the history if maximum absolute value of
+             gradient difference is less than this value. Defaults to None.
+         gtol_restart (bool, optional):
+             If true, whenever gradient difference is less than ``gtol``,
+             L-SR1 state will be reset. Defaults to None.
+         scale_first (bool, optional):
+             makes the first step, when a hessian approximation is not yet available,
+             small to reduce the number of line search iterations. Defaults to False.
+         update_freq (int, optional):
+             how often to update L-SR1 history. Larger values may be better for stochastic optimization. Defaults to 1.
+         damping (DampingStrategyType, optional):
+             damping to use, can be "powell" or "double". Defaults to None.
+         compact (bool, optional):
+             if True, uses a compact-representation version of L-SR1. It is much faster computationally, but less stable.
          inner (Chainable | None, optional):
-             Optional inner modules applied after updating
-             L-SR1 history and before preconditioning. Defaults to None.
-
-     Examples:
-         L-SR1 with Strong-Wolfe+- line search
-
-         .. code-block:: python
-
-             opt = tz.Modular(
-                 model.parameters(),
-                 tz.m.LSR1(100),
-                 tz.m.StrongWolfe(plus_minus=True)
-             )
+             optional inner modules applied after updating L-SR1 history and before preconditioning. Defaults to None.
+
+     ## Examples:
+
+     L-SR1 with line search
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.LSR1(),
+         tz.m.StrongWolfe(c2=0.1, fallback=True)
+     )
+     ```
+
+     L-SR1 with trust region
+     ```python
+     opt = tz.Modular(
+         model.parameters(),
+         tz.m.TrustCG(tz.m.LSR1())
+     )
+     ```
      """
      def __init__(
          self,
-         history_size: int = 10,
-         tol: float | None = 1e-10,
-         tol_reset: bool = False,
-         gtol: float | None = 1e-10,
-         params_beta: float | None = None,
-         grads_beta: float | None = None,
-         update_freq: int = 1,
-         scale_second: bool = False,
+         history_size=10,
+         ptol: float | None = None,
+         ptol_restart: bool = False,
+         gtol: float | None = None,
+         gtol_restart: bool = False,
+         scale_first: bool = False,
+         update_freq=1,
+         damping: DampingStrategyType = None,
          inner: Chainable | None = None,
      ):
          defaults = dict(
-             history_size=history_size, tol=tol, gtol=gtol,
-             params_beta=params_beta, grads_beta=grads_beta,
-             update_freq=update_freq, scale_second=scale_second,
-             tol_reset=tol_reset,
+             history_size=history_size,
+             scale_first=scale_first,
+             ptol=ptol,
+             gtol=gtol,
+             ptol_restart=ptol_restart,
+             gtol_restart=gtol_restart,
+             damping=damping,
          )
-         super().__init__(defaults, uses_grad=False, inner=inner)
+         super().__init__(defaults, uses_grad=False, inner=inner, update_freq=update_freq)
 
          self.global_state['s_history'] = deque(maxlen=history_size)
          self.global_state['y_history'] = deque(maxlen=history_size)
 
-     def reset(self):
+     def _reset_self(self):
          self.state.clear()
          self.global_state['step'] = 0
          self.global_state['s_history'].clear()
          self.global_state['y_history'].clear()
 
+     def reset(self):
+         self._reset_self()
+         for c in self.children.values(): c.reset()
+
      def reset_for_online(self):
          super().reset_for_online()
-         self.clear_state_keys('prev_l_params', 'prev_l_grad')
+         self.clear_state_keys('p_prev', 'g_prev')
          self.global_state.pop('step', None)
 
      @torch.no_grad
      def update_tensors(self, tensors, params, grads, loss, states, settings):
-         params = as_tensorlist(params)
-         update = as_tensorlist(tensors)
+         p = as_tensorlist(params)
+         g = as_tensorlist(tensors)
          step = self.global_state.get('step', 0)
          self.global_state['step'] = step + 1
 
-         s_history: deque[TensorList] = self.global_state['s_history']
-         y_history: deque[TensorList] = self.global_state['y_history']
+         # history of s and y
+         s_history: deque = self.global_state['s_history']
+         y_history: deque = self.global_state['y_history']
 
-         setting = settings[0]
-         update_freq = itemgetter('update_freq')(setting)
+         ptol = self.defaults['ptol']
+         gtol = self.defaults['gtol']
+         ptol_restart = self.defaults['ptol_restart']
+         gtol_restart = self.defaults['gtol_restart']
+         damping = self.defaults['damping']
 
-         params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')
-         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-         prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)
+         p_prev, g_prev = unpack_states(states, tensors, 'p_prev', 'g_prev', cls=TensorList)
 
-         s = None
-         y = None
-         if step != 0:
-             if step % update_freq == 0:
-                 s = l_params - prev_l_params
-                 y = l_update - prev_l_grad
+         # 1st step - there are no previous params and grads, lsr1 will do normalized SGD step
+         if step == 0:
+             s = None; y = None; sy = None
+         else:
+             s = p - p_prev
+             y = g - g_prev
 
-                 s_history.append(s)
-                 y_history.append(y)
+             if damping is not None:
+                 s, y = apply_damping(damping, s=s, y=y, g=g, H=self.get_H())
 
-             prev_l_params.copy_(l_params)
-             prev_l_grad.copy_(l_update)
+             sy = s.dot(y)
+             # damping to be added here
 
-         # store for apply
-         self.global_state['s'] = s
-         self.global_state['y'] = y
+         below_tol = False
+         # tolerance on parameter difference to avoid exploding after converging
+         if ptol is not None:
+             if s is not None and s.abs().global_max() <= ptol:
+                 if ptol_restart: self._reset_self()
+                 sy = None
+                 below_tol = True
+
+         # tolerance on gradient difference to avoid exploding when there is no curvature
+         if gtol is not None:
+             if y is not None and y.abs().global_max() <= gtol:
+                 if gtol_restart: self._reset_self()
+                 sy = None
+                 below_tol = True
+
+         # store previous params and grads
+         if not below_tol:
+             p_prev.copy_(p)
+             g_prev.copy_(g)
+
+         # update effective preconditioning state
+         if sy is not None:
+             assert s is not None and y is not None and sy is not None
+
+             s_history.append(s)
+             y_history.append(y)
+
+     def get_H(self, var=...):
+         s_history = [tl.to_vec() for tl in self.global_state['s_history']]
+         y_history = [tl.to_vec() for tl in self.global_state['y_history']]
+         return LSR1LinearOperator(s_history, y_history)
 
      @torch.no_grad
      def apply_tensors(self, tensors, params, grads, loss, states, settings):
-         tensors = as_tensorlist(tensors)
-         s = self.global_state.pop('s')
-         y = self.global_state.pop('y')
+         scale_first = self.defaults['scale_first']
 
-         setting = settings[0]
-         tol = setting['tol']
-         gtol = setting['gtol']
-         tol_reset = setting['tol_reset']
+         tensors = as_tensorlist(tensors)
 
-         # tolerance on parameter difference to avoid exploding after converging
-         if tol is not None:
-             if s is not None and s.abs().global_max() <= tol:
-                 if tol_reset: self.reset()
-                 return safe_scaling_(TensorList(tensors))
-
-         # tolerance on gradient difference to avoid exploding when there is no curvature
-         if tol is not None:
-             if y is not None and y.abs().global_max() <= gtol:
-                 return safe_scaling_(TensorList(tensors))
+         s_history = self.global_state['s_history']
+         y_history = self.global_state['y_history']
 
          # precondition
-         dir = lsr1_(
-             tensors_=tensors,
-             s_history=self.global_state['s_history'],
-             y_history=self.global_state['y_history'],
-             step=self.global_state.get('step', 1),
-             scale_second=setting['scale_second'],
+         dir = lsr1_Hx(
+             x=tensors,
+             s_history=s_history,
+             y_history=y_history,
          )
 
-         return dir
+         # scale 1st step
+         if scale_first and self.global_state.get('step', 1) == 1:
+             dir *= initial_step_size(dir, eps=1e-7)
+
+         return dir
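
The first loop of the new `lsr1_Hx` builds the vectors `w_k = s_k - H_k y_k`, and the second loop applies `H_m x = x + sum_k (w_k . x / w_k . y_k) w_k`, which is the standard limited-memory SR1 inverse-Hessian recursion with `H_0 = I` (each pair adds the rank-one term `w_k w_k^T / (w_k^T y_k)`). A minimal standalone sketch of that recursion on plain 1-D tensors, checked against the dense update; `sr1_Hx_recursive` and `sr1_H_dense` are illustrative names, not torchzero APIs:

```python
import torch

torch.manual_seed(0)
n, m = 6, 4
s_hist = [torch.randn(n, dtype=torch.float64) for _ in range(m)]
y_hist = [torch.randn(n, dtype=torch.float64) for _ in range(m)]

def sr1_Hx_recursive(x, s_hist, y_hist):
    # first loop: build every w_k = s_k - H_k y_k without forming H_k explicitly
    w_list, wy_list = [], []
    for s_k, y_k in zip(s_hist, y_hist):
        Hy = y_k.clone()
        for w_j, wy_j in zip(w_list, wy_list):
            Hy += (w_j.dot(y_k) / wy_j) * w_j
        w_k = s_k - Hy
        w_list.append(w_k)
        wy_list.append(w_k.dot(y_k))
    # second loop: H_m x = x + sum_k (w_k . x / w_k . y_k) w_k
    Hx = x.clone()
    for w_k, wy_k in zip(w_list, wy_list):
        Hx += (w_k.dot(x) / wy_k) * w_k
    return Hx

def sr1_H_dense(s_hist, y_hist):
    # explicit SR1 updates on a dense matrix, for comparison
    H = torch.eye(len(s_hist[0]), dtype=torch.float64)
    for s_k, y_k in zip(s_hist, y_hist):
        w = s_k - H @ y_k
        H = H + torch.outer(w, w) / w.dot(y_k)
    return H

x = torch.randn(n, dtype=torch.float64)
diff = (sr1_Hx_recursive(x, s_hist, y_hist) - sr1_H_dense(s_hist, y_hist) @ x).abs().max()
print(diff)  # expected to be near machine precision for well-conditioned pairs
```

The same identity is what `LSR1LinearOperator.solve` relies on; swapping the roles of `s` and `y`, as `lsr1_Bx` does, yields the direct (rather than inverse) Hessian approximation, which is what `matvec` exposes to trust-region modules such as `tz.m.TrustCG`.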