torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/lbfgs.py
@@ -1,77 +1,76 @@
 from collections import deque
 from operator import itemgetter
+
 import torch

-from ...core import Transform, Chainable, Module, Vars, apply
-from ...utils import TensorList, as_tensorlist, NumberList
+from ...core import Chainable, Module, Transform, Var, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..functional import safe_scaling_


 def _adaptive_damping(
-    s_k: TensorList,
-    y_k: TensorList,
-    ys_k: torch.Tensor,
+    s: TensorList,
+    y: TensorList,
+    sy: torch.Tensor,
     init_damping = 0.99,
     eigval_bounds = (0.01, 1.5)
 ):
     # adaptive damping Al-Baali, M.: Quasi-Wolfe conditions for quasi-Newton methods for large-scale optimization. In: 40th Workshop on Large Scale Nonlinear Optimization, Erice, Italy, June 22–July 1 (2004)
     sigma_l, sigma_h = eigval_bounds
-    u = ys_k / s_k.dot(s_k)
+    u = sy / s.dot(s)
     if u <= sigma_l < 1: tau = min((1-sigma_l)/(1-u), init_damping)
     elif u >= sigma_h > 1: tau = min((sigma_h-1)/(u-1), init_damping)
     else: tau = init_damping
-    y_k = tau * y_k + (1-tau) * s_k
-    ys_k = s_k.dot(y_k)
+    y = tau * y + (1-tau) * s
+    sy = s.dot(y)

-    return s_k, y_k, ys_k
+    return s, y, sy


 def lbfgs(
     tensors_: TensorList,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
-    y_k: TensorList | None,
-    ys_k: torch.Tensor | None,
+    y: TensorList | None,
+    sy: torch.Tensor | None,
     z_beta: float | None,
     z_ema: TensorList | None,
     step: int,
 ):
-    if len(s_history) == 0 or y_k is None or ys_k is None:
+    if len(s_history) == 0 or y is None or sy is None:

         # initial step size guess modified from pytorch L-BFGS
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
-
-    else:
-        # 1st loop
-        alpha_list = []
-        q = tensors_.clone()
-        for s_i, y_i, ys_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
-            p_i = 1 / ys_i # this is also denoted as ρ (rho)
-            alpha = p_i * s_i.dot(q)
-            alpha_list.append(alpha)
-            q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
-
-        # calculate z
-        # s.y/y.y is also this weird y-looking symbol I couldn't find
-        # z is it times q
-        # actually H0 = (s.y/y.y) * I, and z = H0 @ q
-        z = q * (ys_k / (y_k.dot(y_k)))
-
-        # an attempt into adding momentum, lerping initial z seems stable compared to other variables
-        if z_beta is not None:
-            assert z_ema is not None
-            if step == 0: z_ema.copy_(z)
-            else: z_ema.lerp(z, 1-z_beta)
-            z = z_ema
-
-        # 2nd loop
-        for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
-            p_i = 1 / ys_i
-            beta_i = p_i * y_i.dot(z)
-            z.add_(s_i, alpha = alpha_i - beta_i)
-
-        return z
+        return safe_scaling_(TensorList(tensors_))
+
+    # 1st loop
+    alpha_list = []
+    q = tensors_.clone()
+    for s_i, y_i, sy_i in zip(reversed(s_history), reversed(y_history), reversed(sy_history)):
+        p_i = 1 / sy_i # this is also denoted as ρ (rho)
+        alpha = p_i * s_i.dot(q)
+        alpha_list.append(alpha)
+        q.sub_(y_i, alpha=alpha) # pyright: ignore[reportArgumentType]
+
+    # calculate z
+    # s.y/y.y is also this weird y-looking symbol I couldn't find
+    # z is it times q
+    # actually H0 = (s.y/y.y) * I, and z = H0 @ q
+    z = q * (sy / (y.dot(y)))
+
+    # an attempt into adding momentum, lerping initial z seems stable compared to other variables
+    if z_beta is not None:
+        assert z_ema is not None
+        if step == 1: z_ema.copy_(z)
+        else: z_ema.lerp(z, 1-z_beta)
+        z = z_ema
+
+    # 2nd loop
+    for s_i, y_i, sy_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
+        p_i = 1 / sy_i
+        beta_i = p_i * y_i.dot(z)
+        z.add_(s_i, alpha = alpha_i - beta_i)
+
+    return z

 def _lerp_params_update_(
     self_: Module,
@@ -96,19 +95,24 @@ def _lerp_params_update_(

     return TensorList(params), TensorList(update)

-class LBFGS(Module):
-    """L-BFGS
+class LBFGS(Transform):
+    """Limited-memory BFGS algorithm. A line search is recommended, although L-BFGS may be reasonably stable without it.

     Args:
-        history_size (int, optional): number of past parameter differences and gradient differences to store. Defaults to 10.
-        tol (float | None, optional):
-            tolerance for minimal gradient difference to avoid instability after converging to minima. Defaults to 1e-10.
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
         damping (bool, optional):
             whether to use adaptive damping. Learning rate might need to be lowered with this enabled. Defaults to False.
         init_damping (float, optional):
             initial damping for adaptive dampening. Defaults to 0.9.
         eigval_bounds (tuple, optional):
             eigenvalue bounds for adaptive dampening. Defaults to (0.5, 50).
+        tol (float | None, optional):
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
+        tol_reset (bool, optional):
+            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+        gtol (float | None, optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
         params_beta (float | None, optional):
             if not None, EMA of parameters is used for preconditioner update. Defaults to None.
         grads_beta (float | None, optional):
@@ -117,35 +121,62 @@ class LBFGS(Module):
             how often to update L-BFGS history. Defaults to 1.
         z_beta (float | None, optional):
             optional EMA for initial H^-1 @ q. Acts as a kind of momentum but is prone to get stuck. Defaults to None.
-        tol_reset (bool, optional):
-            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
         inner (Chainable | None, optional):
             optional inner modules applied after updating L-BFGS history and before preconditioning. Defaults to None.
+
+    Examples:
+        L-BFGS with strong-wolfe line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(100),
+                tz.m.StrongWolfe()
+            )
+
+        Dampened L-BFGS
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(damping=True),
+                tz.m.StrongWolfe()
+            )
+
+        L-BFGS preconditioning applied to momentum (may be unstable!)
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LBFGS(inner=tz.m.EMA(0.9)),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
         history_size=10,
-        tol: float | None = 1e-10,
         damping: bool = False,
         init_damping=0.9,
         eigval_bounds=(0.5, 50),
+        tol: float | None = 1e-10,
+        tol_reset: bool = False,
+        gtol: float | None = 1e-10,
         params_beta: float | None = None,
         grads_beta: float | None = None,
         update_freq = 1,
         z_beta: float | None = None,
-        tol_reset: bool = False,
         inner: Chainable | None = None,
     ):
-        defaults = dict(history_size=history_size, tol=tol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
-        super().__init__(defaults)
+        defaults = dict(history_size=history_size, tol=tol, gtol=gtol, damping=damping, init_damping=init_damping, eigval_bounds=eigval_bounds, params_beta=params_beta, grads_beta=grads_beta, update_freq=update_freq, z_beta=z_beta, tol_reset=tol_reset)
+        super().__init__(defaults, uses_grad=False, inner=inner)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)
         self.global_state['sy_history'] = deque(maxlen=history_size)

-        if inner is not None:
-            self.set_child('inner', inner)
-
     def reset(self):
         self.state.clear()
         self.global_state['step'] = 0
@@ -153,10 +184,15 @@ class LBFGS(Module):
         self.global_state['y_history'].clear()
         self.global_state['sy_history'].clear()

+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_l_params', 'prev_l_grad')
+        self.global_state.pop('step', None)
+
     @torch.no_grad
-    def step(self, vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        params = as_tensorlist(params)
+        update = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -165,65 +201,86 @@ class LBFGS(Module):
         y_history: deque[TensorList] = self.global_state['y_history']
         sy_history: deque[torch.Tensor] = self.global_state['sy_history']

-        tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
-            'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-        params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params)
+        damping,init_damping,eigval_bounds,update_freq = itemgetter('damping','init_damping','eigval_bounds','update_freq')(settings[0])
+        params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')

         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)

-        # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
+        # 1st step - there are no previous params and grads, lbfgs will do normalized SGD step
         if step == 0:
-            s_k = None; y_k = None; ys_k = None
+            s = None; y = None; sy = None
         else:
-            s_k = l_params - prev_l_params
-            y_k = l_update - prev_l_grad
-            ys_k = s_k.dot(y_k)
+            s = l_params - prev_l_params
+            y = l_update - prev_l_grad
+            sy = s.dot(y)

             if damping:
-                s_k, y_k, ys_k = _adaptive_damping(s_k, y_k, ys_k, init_damping=init_damping, eigval_bounds=eigval_bounds)
+                s, y, sy = _adaptive_damping(s, y, sy, init_damping=init_damping, eigval_bounds=eigval_bounds)

         prev_l_params.copy_(l_params)
         prev_l_grad.copy_(l_update)

         # update effective preconditioning state
         if step % update_freq == 0:
-            if ys_k is not None and ys_k > 1e-10:
-                assert s_k is not None and y_k is not None
-                s_history.append(s_k)
-                y_history.append(y_k)
-                sy_history.append(ys_k)
+            if sy is not None and sy > 1e-10:
+                assert s is not None and y is not None
+                s_history.append(s)
+                y_history.append(y)
+                sy_history.append(sy)
+
+        # store for apply
+        self.global_state['s'] = s
+        self.global_state['y'] = y
+        self.global_state['sy'] = sy
+
+    def make_Hv(self):
+        ...

-        # step with inner module before applying preconditioner
-        if self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+    def make_Bv(self):
+        ...

-        # tolerance on gradient difference to avoid exploding after converging
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+
+        s = self.global_state.pop('s')
+        y = self.global_state.pop('y')
+        sy = self.global_state.pop('sy')
+
+        setting = settings[0]
+        tol = setting['tol']
+        gtol = setting['gtol']
+        tol_reset = setting['tol_reset']
+        z_beta = setting['z_beta']
+
+        # tolerance on parameter difference to avoid exploding after converging
         if tol is not None:
-            if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
+            if s is not None and s.abs().global_max() <= tol:
                 if tol_reset: self.reset()
-                return vars
+                return safe_scaling_(TensorList(tensors))
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
+        if tol is not None:
+            if y is not None and y.abs().global_max() <= gtol:
+                return safe_scaling_(TensorList(tensors))

         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+            z_ema = unpack_states(states, tensors, 'z_ema', cls=TensorList)

         # precondition
         dir = lbfgs(
-            tensors_=as_tensorlist(update),
-            s_history=s_history,
-            y_history=y_history,
-            sy_history=sy_history,
-            y_k=y_k,
-            ys_k=ys_k,
+            tensors_=tensors,
+            s_history=self.global_state['s_history'],
+            y_history=self.global_state['y_history'],
+            sy_history=self.global_state['sy_history'],
+            y=y,
+            sy=sy,
             z_beta = z_beta,
             z_ema = z_ema,
-            step=step
+            step=self.global_state.get('step', 1)
         )

-        vars.update = dir
-
-        return vars
-
+        return dir
torchzero/modules/quasi_newton/lsr1.py
@@ -3,11 +3,12 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Module, Transform, Vars, apply
-from ...utils import NumberList, TensorList, as_tensorlist
-
+from ...core import Chainable, Module, Transform, Var, apply_transform
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
+from ..functional import safe_scaling_
 from .lbfgs import _lerp_params_update_

+
 def lsr1_(
     tensors_: TensorList,
     s_history: deque[TensorList],
@@ -15,11 +16,9 @@ def lsr1_(
     step: int,
     scale_second: bool,
 ):
-    if step == 0 or not s_history:
+    if len(s_history) == 0:
         # initial step size guess from pytorch
-        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
-        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
-        return tensors_.mul_(scale_factor)
+        return safe_scaling_(TensorList(tensors_))

     m = len(s_history)

@@ -64,7 +63,7 @@ def lsr1_(

         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

-    if scale_second and step == 1:
+    if scale_second and step == 2:
         scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
         scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
         Hx.mul_(scale_factor)
@@ -72,103 +71,148 @@ def lsr1_(
     return Hx


-class LSR1(Module):
-    """Limited Memory SR1 (L-SR1)
+class LSR1(Transform):
+    """Limited Memory SR1 algorithm. A line search is recommended.
+
+    .. note::
+        L-SR1 provides a better estimate of true hessian, however it is more unstable compared to L-BFGS.
+
+    .. note::
+        L-SR1 update rule uses a nested loop, computationally with history size `n` it is similar to L-BFGS with history size `(n^2)/2`. On small problems (ndim <= 2000) BFGS and SR1 may be faster than limited-memory versions.
+
+    .. note::
+        directions L-SR1 generates are not guaranteed to be descent directions. This can be alleviated in multiple ways,
+        for example using :code:`tz.m.StrongWolfe(plus_minus=True)` line search, or modifying the direction with :code:`tz.m.Cautious` or :code:`tz.m.ScaleByGradCosineSimilarity`.
+
     Args:
-        history_size (int, optional): Number of past parameter differences (s)
-            and gradient differences (y) to store. Defaults to 10.
-        skip_R_val (float, optional): Tolerance R for the SR1 update skip condition
-            |w_k^T y_k| >= R * ||w_k|| * ||y_k||. Defaults to 1e-8.
-            Updates where this condition is not met are skipped during history accumulation
-            and matrix-vector products.
-        params_beta (float | None, optional): If not None, EMA of parameters is used for
+        history_size (int, optional):
+            number of past parameter differences and gradient differences to store. Defaults to 10.
+        tol (float | None, optional):
+            tolerance for minimal parameter difference to avoid instability. Defaults to 1e-10.
+        tol_reset (bool, optional):
+            If true, whenever gradient difference is less then `tol`, the history will be reset. Defaults to None.
+        gtol (float | None, optional):
+            tolerance for minimal gradient difference to avoid instability when there is no curvature. Defaults to 1e-10.
+        params_beta (float | None, optional):
+            if not None, EMA of parameters is used for
             preconditioner update (s_k vector). Defaults to None.
-        grads_beta (float | None, optional): If not None, EMA of gradients is used for
+        grads_beta (float | None, optional):
+            if not None, EMA of gradients is used for
             preconditioner update (y_k vector). Defaults to None.
         update_freq (int, optional): How often to update L-SR1 history. Defaults to 1.
-        conv_tol (float | None, optional): Tolerance for y_k norm. If max abs value of y_k
-            is below this, the preconditioning step might be skipped, assuming convergence.
-            Defaults to 1e-10.
-        inner (Chainable | None, optional): Optional inner modules applied after updating
+        scale_second (bool, optional): downscales second update which tends to be large. Defaults to False.
+        inner (Chainable | None, optional):
+            Optional inner modules applied after updating
             L-SR1 history and before preconditioning. Defaults to None.
+
+    Examples:
+        L-SR1 with Strong-Wolfe+- line search
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.LSR1(100),
+                tz.m.StrongWolfe(plus_minus=True)
+            )
     """
     def __init__(
         self,
         history_size: int = 10,
-        tol: float = 1e-8,
+        tol: float | None = 1e-10,
+        tol_reset: bool = False,
+        gtol: float | None = 1e-10,
         params_beta: float | None = None,
         grads_beta: float | None = None,
         update_freq: int = 1,
-        scale_second: bool = True,
+        scale_second: bool = False,
         inner: Chainable | None = None,
     ):
         defaults = dict(
-            history_size=history_size, tol=tol,
+            history_size=history_size, tol=tol, gtol=gtol,
            params_beta=params_beta, grads_beta=grads_beta,
-            update_freq=update_freq, scale_second=scale_second
+            update_freq=update_freq, scale_second=scale_second,
+            tol_reset=tol_reset,
         )
-        super().__init__(defaults)
+        super().__init__(defaults, uses_grad=False, inner=inner)

         self.global_state['s_history'] = deque(maxlen=history_size)
         self.global_state['y_history'] = deque(maxlen=history_size)

-        if inner is not None:
-            self.set_child('inner', inner)
-
     def reset(self):
         self.state.clear()
         self.global_state['step'] = 0
         self.global_state['s_history'].clear()
         self.global_state['y_history'].clear()

+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_l_params', 'prev_l_grad')
+        self.global_state.pop('step', None)

     @torch.no_grad
-    def step(self, vars: Vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def update_tensors(self, tensors, params, grads, loss, states, settings):
+        params = as_tensorlist(params)
+        update = as_tensorlist(tensors)
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

         s_history: deque[TensorList] = self.global_state['s_history']
         y_history: deque[TensorList] = self.global_state['y_history']

-        settings = self.settings[params[0]]
-        tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)
-
-        params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta', params=params) # type: ignore
-        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)
+        setting = settings[0]
+        update_freq = itemgetter('update_freq')(setting)

-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        params_beta, grads_beta = unpack_dicts(settings, 'params_beta', 'grads_beta')
+        l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
+        prev_l_params, prev_l_grad = unpack_states(states, tensors, 'prev_l_params', 'prev_l_grad', cls=TensorList)

-        y_k = None
+        s = None
+        y = None
         if step != 0:
             if step % update_freq == 0:
-                s_k = l_params - prev_l_params
-                y_k = l_update - prev_l_grad
+                s = l_params - prev_l_params
+                y = l_update - prev_l_grad

-                s_history.append(s_k)
-                y_history.append(y_k)
+                s_history.append(s)
+                y_history.append(y)

             prev_l_params.copy_(l_params)
             prev_l_grad.copy_(l_update)

-        if 'inner' in self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+        # store for apply
+        self.global_state['s'] = s
+        self.global_state['y'] = y
+
+    @torch.no_grad
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
+        tensors = as_tensorlist(tensors)
+        s = self.global_state.pop('s')
+        y = self.global_state.pop('y')

-        # tolerance on gradient difference to avoid exploding after converging
+        setting = settings[0]
+        tol = setting['tol']
+        gtol = setting['gtol']
+        tol_reset = setting['tol_reset']
+
+        # tolerance on parameter difference to avoid exploding after converging
+        if tol is not None:
+            if s is not None and s.abs().global_max() <= tol:
+                if tol_reset: self.reset()
+                return safe_scaling_(TensorList(tensors))
+
+        # tolerance on gradient difference to avoid exploding when there is no curvature
         if tol is not None:
-            if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update
-                return vars
+            if y is not None and y.abs().global_max() <= gtol:
+                return safe_scaling_(TensorList(tensors))

+        # precondition
         dir = lsr1_(
-            tensors_=update,
-            s_history=s_history,
-            y_history=y_history,
-            step=step,
-            scale_second=scale_second,
+            tensors_=tensors,
+            s_history=self.global_state['s_history'],
+            y_history=self.global_state['y_history'],
+            step=self.global_state.get('step', 1),
+            scale_second=setting['scale_second'],
         )

-        vars.update = dir
-
-        return vars
+        return dir
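
The diff above moves both LBFGS and LSR1 from the Module base class onto Transform, splitting the old step(self, vars) method into update_tensors (which records the s/y history) and apply_tensors (which returns the preconditioned direction). For readers following that API change, here is a minimal sketch of a custom transform written against the same split. It assumes the usual `import torchzero as tz` alias and an existing `model`; the SignScale class, its `scale` setting, and the final usage line are illustrative examples, not part of the package.

    import torch
    import torchzero as tz
    from torchzero.core import Transform


    class SignScale(Transform):
        """Illustrative transform: replaces each update tensor with its sign times `scale`."""

        def __init__(self, scale: float = 1.0):
            defaults = dict(scale=scale)
            # uses_grad=False mirrors how LBFGS/LSR1 initialize Transform in the diff above
            super().__init__(defaults, uses_grad=False)

        @torch.no_grad
        def update_tensors(self, tensors, params, grads, loss, states, settings):
            # stateful bookkeeping goes here; LBFGS/LSR1 record s, y (and sy) at this point
            pass

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # settings is a per-parameter list of dicts; index 0 is enough for a global value
            scale = settings[0]['scale']
            return [t.sign().mul_(scale) for t in tensors]


    # chained the same way as the docstring examples in the diff
    opt = tz.Modular(model.parameters(), SignScale(), tz.m.LR(1e-2))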