torchzero 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/cg.py
@@ -1,14 +1,15 @@
 from abc import ABC, abstractmethod
+from typing import Literal

 import torch

-from ...core import Chainable, Transform, apply
-from ...utils import TensorList, as_tensorlist
+from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states


 class ConguateGradientBase(Transform, ABC):
     """all CGs are the same except beta calculation"""
-    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['reset_interval'] = reset_interval
         defaults['clip_beta'] = clip_beta
@@ -25,12 +26,12 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = as_tensorlist(tensors)
         params = as_tensorlist(params)

         step = self.global_state.get('step', 0)
-        prev_dir, prev_grads = self.get_state('prev_dir', 'prev_grad', params=params, cls=TensorList)
+        prev_dir, prev_grads = unpack_states(states, tensors, 'prev_dir', 'prev_grad', cls=TensorList)

         # initialize on first step
         if step == 0:
@@ -42,12 +43,12 @@ class ConguateGradientBase(Transform, ABC):

         # get beta
         beta = self.get_beta(params, tensors, prev_grads, prev_dir)
-        if self.settings[params[0]]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
         prev_grads.copy_(tensors)

         # inner step
         if 'inner' in self.children:
-            tensors = as_tensorlist(apply(self.children['inner'], tensors, params, grads, vars))
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))

         # calculate new direction with beta
         dir = tensors.add_(prev_dir.mul_(beta))
@@ -55,7 +56,8 @@ class ConguateGradientBase(Transform, ABC):

         # resetting
         self.global_state['step'] = step + 1
-        reset_interval = self.settings[params[0]]['reset_interval']
+        reset_interval = settings[0]['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
         if reset_interval is not None and (step+1) % reset_interval == 0:
             self.reset()

@@ -64,7 +66,7 @@
 # ------------------------------- Polak-Ribière ------------------------------ #
 def polak_ribiere_beta(g: TensorList, prev_g: TensorList):
     denom = prev_g.dot(prev_g)
-    if denom == 0: return 0
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return g.dot(g - prev_g) / denom

 class PolakRibiere(ConguateGradientBase):
@@ -76,13 +78,13 @@ class PolakRibiere(ConguateGradientBase):
         return polak_ribiere_beta(g, prev_g)

 # ------------------------------ Fletcher–Reeves ----------------------------- #
-def fletcher_reeves_beta(gg, prev_gg):
-    if prev_gg == 0: return 0
+def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):
+    if prev_gg.abs() <= torch.finfo(gg.dtype).eps: return 0
     return gg / prev_gg

 class FletcherReeves(ConguateGradientBase):
     """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def initialize(self, p, g):
@@ -98,13 +100,13 @@ class FletcherReeves(ConguateGradientBase):
 def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     grad_diff = g - prev_g
     denom = prev_d.dot(grad_diff)
-    if denom == 0: return 0
+    if denom.abs() < torch.finfo(g[0].dtype).eps: return 0
     return (g.dot(grad_diff) / denom).neg()


 class HestenesStiefel(ConguateGradientBase):
     """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -114,12 +116,12 @@ class HestenesStiefel(ConguateGradientBase):
 # --------------------------------- Dai–Yuan --------------------------------- #
 def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     denom = prev_d.dot(g - prev_g)
-    if denom == 0: return 0
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return (g.dot(g) / denom).neg()

 class DaiYuan(ConguateGradientBase):
     """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -129,12 +131,12 @@ class DaiYuan(ConguateGradientBase):
 # -------------------------------- Liu-Storey -------------------------------- #
 def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):
     denom = prev_g.dot(prev_d)
-    if denom == 0: return 0
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0
     return g.dot(g - prev_g) / denom

 class LiuStorey(ConguateGradientBase):
     """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -151,20 +153,20 @@ class ConjugateDescent(Transform):


     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         g = as_tensorlist(tensors)

-        prev_d = self.get_state('prev_dir', params=params, cls=TensorList, init = torch.zeros_like)
+        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
         if 'denom' not in self.global_state:
             self.global_state['denom'] = torch.tensor(0.).to(g[0])

         prev_gd = self.global_state.get('prev_gd', 0)
-        if prev_gd == 0: beta = 0
+        if abs(prev_gd) <= torch.finfo(g[0].dtype).eps: beta = 0
         else: beta = g.dot(g) / prev_gd

         # inner step
         if 'inner' in self.children:
-            g = as_tensorlist(apply(self.children['inner'], g, params, grads, vars))
+            g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))

         dir = g.add_(prev_d.mul_(beta))
         prev_d.copy_(dir)
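For orientation, the recurring change in the hunks above replaces exact zero checks on the beta denominators with a machine-epsilon guard. A minimal standalone sketch of the guarded Polak-Ribière beta on flat torch tensors (illustrative only; the package versions operate on its TensorList type):

```python
import torch

def polak_ribiere_beta(g: torch.Tensor, prev_g: torch.Tensor) -> torch.Tensor:
    # denominator guarded with the dtype's machine epsilon instead of `denom == 0`,
    # so a numerically-zero previous gradient falls back to beta = 0 (a restart)
    denom = prev_g.dot(prev_g)
    if denom.abs() <= torch.finfo(g.dtype).eps:
        return torch.zeros((), dtype=g.dtype, device=g.device)
    return g.dot(g - prev_g) / denom

g, prev_g = torch.randn(128), torch.randn(128)
beta = polak_ribiere_beta(g, prev_g)   # safe even when prev_g is all zeros
```

The same guard is applied to every other beta helper in this file.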
@@ -176,7 +178,7 @@ class ConjugateDescent(Transform):
 def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
     g_diff = g - prev_g
     denom = prev_d.dot(g_diff)
-    if denom == 0: return 0
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

     term1 = 1/denom
     # term2
@@ -187,7 +189,7 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
 class HagerZhang(ConguateGradientBase):
     """Hager-Zhang nonlinear conjugate gradient method,
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -198,7 +200,7 @@ class HagerZhang(ConguateGradientBase):
 def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
     grad_diff = g - prev_g
     denom = prev_d.dot(grad_diff)
-    if denom == 0: return 0
+    if denom.abs() <= torch.finfo(g[0].dtype).eps: return 0

     # Dai-Yuan
     dy_beta = (g.dot(g) / denom).neg().clamp(min=0)
@@ -211,8 +213,56 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 class HybridHS_DY(ConguateGradientBase):
     """HS-DY hybrid conjugate gradient method.
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hs_dy_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H @ y
+    denom = y.dot(Hy)
+    if denom.abs() < tol: return H
+    H -= (H @ y.outer(y) @ H) / denom
+    return H
+
+class ProjectedGradientMethod(TensorwiseTransform):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    (This is not the same as projected gradient descent)
+    """
+
+    def __init__(
+        self,
+        tol: float = 1e-10,
+        reset_interval: int | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reset_interval=reset_interval, tol=tol)
+        super().__init__(defaults, uses_grad=False, scale_first=scale_first, concat_params=concat_params, update_freq=update_freq, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+        reset_interval = settings['reset_interval']
+        if reset_interval is None: reset_interval = tensor.numel() + 1 # as recommended
+
+        if ("H" not in state) or (step % reset_interval == 0):
+            state["H"] = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype)
+            state['g_prev'] = tensor.clone()
+            return
+
+        H = state['H']
+        g_prev = state['g_prev']
+        state['g_prev'] = tensor.clone()
+        y = (tensor - g_prev).ravel()
+
+        projected_gradient_(H, y, settings['tol'])
+
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        H = state['H']
+        return (H @ tensor.view(-1)).view_as(tensor)
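The new ProjectedGradientMethod keeps a dense matrix H per (concatenated) tensor and shrinks it along each gradient difference y via H ← H − (H y yᵀ H) / (yᵀ H y), then preconditions the update with H. A hedged toy driver for that update on a 2-D quadratic, reusing the same formula as the projected_gradient_ helper added above (the loop and fixed step size are illustrative only, not the module's logic; Pearson's method pairs this with a line search):

```python
import torch

def projected_gradient_(H: torch.Tensor, y: torch.Tensor, tol: float = 1e-10) -> torch.Tensor:
    # H <- H - (H y y^T H) / (y^T H y), skipped when the curvature term is below tol
    denom = y.dot(H @ y)
    if denom.abs() < tol:
        return H
    H -= (H @ y.outer(y) @ H) / denom
    return H

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])   # f(x) = 0.5 * x^T A x, grad f(x) = A x
x = torch.tensor([1.0, -1.0])
H = torch.eye(2)
g_prev = A @ x
for _ in range(3):
    x = x - 0.1 * (H @ g_prev)               # precondition the gradient with H
    g = A @ x
    projected_gradient_(H, g - g_prev, tol=1e-10)
    g_prev = g
```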
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
@@ -4,7 +4,7 @@ from typing import Any

 import torch

-from ....core import Chainable, Module, Transform, Vars, apply, maybe_chain
+from ....core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
 from ....utils import NumberList, TensorList, as_tensorlist


@@ -28,7 +28,7 @@ def _adaptive_damping(

 def lbfgs(
     tensors_: TensorList,
-    vars: Vars,
+    var: Var,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
@@ -60,7 +60,7 @@ def lbfgs(
         z = q * (ys_k / (y_k.dot(y_k)))

         if z_tfm is not None:
-            z = TensorList(apply(z_tfm, tensors=z, params=vars.params, grads=vars.grad, vars=vars))
+            z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))

         # 2nd loop
         for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
@@ -73,28 +73,28 @@ def lbfgs(
 def _apply_tfms_into_history(
     self: Module,
     params: list[torch.Tensor],
-    vars: Vars,
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_history_tfm' in self.children:
-        params = apply(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)

     if 'grad_history_tfm' in self.children:
-        update = apply(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=vars.grad, vars=vars)
+        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)

     return params, update

 def _apply_tfms_into_precond(
     self: Module,
     params: list[torch.Tensor],
-    vars: Vars,
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_precond_tfm' in self.children:
-        params = apply(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)

     if 'grad_precond_tfm' in self.children:
-        update = apply(self.children['grad_precond_tfm'], tensors=update, params=params, grads=vars.grad, vars=vars)
+        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)

     return params, update

@@ -165,9 +165,9 @@ class ModularLBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -186,11 +186,11 @@ class ModularLBFGS(Module):
         params_h, update_h = _apply_tfms_into_history(
             self,
             params=params,
-            vars=vars,
+            var=var,
             update=update,
         )

-        prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h', params=params, cls=TensorList)
+        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)

         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -217,16 +217,16 @@ class ModularLBFGS(Module):
         # step with inner module before applying preconditioner
         if 'update_precond_tfm' in self.children:
             update_precond_tfm = self.children['update_precond_tfm']
-            inner_vars = update_precond_tfm.step(vars.clone(clone_update=True))
-            vars.update_attrs_from_clone_(inner_vars)
-            tensors = inner_vars.update
+            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
+            var.update_attrs_from_clone_(inner_var)
+            tensors = inner_var.update
             assert tensors is not None
         else:
             tensors = update.clone()

         # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params, vars=vars, update=update)
-        prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p', params=params, cls=TensorList)
+        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
+        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)

         if step == 0:
             s_k_p = None; y_k_p = None; ys_k_p = None
@@ -245,13 +245,13 @@ class ModularLBFGS(Module):
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return vars
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var

         # precondition
         dir = lbfgs(
             tensors_=as_tensorlist(tensors),
-            vars=vars,
+            var=var,
             s_history=s_history,
             y_history=y_history,
             sy_history=sy_history,
@@ -260,7 +260,7 @@ class ModularLBFGS(Module):
             z_tfm=self.children.get('z_tfm', None),
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var
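Only fragments of the lbfgs() helper show up in the hunks above (the scaling line `z = q * (ys_k / (y_k.dot(y_k)))` and the `# 2nd loop` marker). For orientation, here is the textbook two-loop recursion those fragments belong to, sketched on a single flat tensor; the package version works on TensorList histories and adds damping, eigenvalue bounds and the z_tfm/z_ema hooks, none of which are shown here:

```python
import torch

def two_loop_recursion(g: torch.Tensor,
                       s_hist: list[torch.Tensor],
                       y_hist: list[torch.Tensor]) -> torch.Tensor:
    """Approximate H^{-1} g from stored (s_k, y_k) pairs, oldest first in the lists."""
    q = g.clone()
    alphas = []
    # 1st loop: newest pair to oldest
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        a = s.dot(q) / y.dot(s)
        q -= a * y
        alphas.append(a)
    # initial Hessian guess gamma * I with gamma = (s.y) / (y.y), the role played by
    # `z = q * (ys_k / (y_k.dot(y_k)))` in the hunk above
    s, y = s_hist[-1], y_hist[-1]
    z = q * (s.dot(y) / y.dot(y))
    # 2nd loop: oldest pair to newest
    for s, y, a in zip(s_hist, y_hist, reversed(alphas)):
        b = y.dot(z) / y.dot(s)
        z += s * (a - b)
    return z   # H^{-1} g; negate it if you need an explicit descent direction
```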
 
torchzero/modules/quasi_newton/lbfgs.py
@@ -2,7 +2,7 @@ from collections import deque
 from operator import itemgetter
 import torch

-from ...core import Transform, Chainable, Module, Vars, apply
+from ...core import Transform, Chainable, Module, Var, apply_transform
 from ...utils import TensorList, as_tensorlist, NumberList


@@ -38,9 +38,9 @@ def lbfgs(
     if len(s_history) == 0 or y_k is None or ys_k is None:

         # initial step size guess modified from pytorch L-BFGS
-        scale = 1 / tensors_.abs().global_sum()
-        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        return tensors_.mul_(scale_factor)

     else:
         # 1st loop
@@ -154,9 +154,9 @@ class LBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -167,10 +167,10 @@ class LBFGS(Module):

         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-        params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params)
+        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')

         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)

         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -196,19 +196,19 @@ class LBFGS(Module):

         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
                 if tol_reset: self.reset()
-                return vars
+                return var

         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)

         # precondition
         dir = lbfgs(
@@ -223,7 +223,7 @@ class LBFGS(Module):
             step=step
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var
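The replaced first-step scaling above is equivalent to multiplying the very first update by min(1, 1 / ||g||_1), floored at machine epsilon, instead of the old `1e-5`-triggered fallback to 1 / mean|g|. A flat-tensor sketch of the new behaviour (single tensor, not a TensorList):

```python
import torch

def first_step_scale(g: torch.Tensor) -> torch.Tensor:
    # clipping the L1 norm from below at 1 caps the factor at 1;
    # clipping the result at eps keeps it from underflowing to zero for huge gradients
    scale = 1.0 / g.abs().sum().clip(min=1)
    return scale.clip(min=torch.finfo(g.dtype).eps)

g = torch.randn(10_000)
first_update = g * first_step_scale(g)   # normalized-SGD-like step before any (s, y) pairs exist
```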
 
torchzero/modules/quasi_newton/lsr1.py
@@ -3,7 +3,7 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Module, Transform, Vars, apply
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist

 from .lbfgs import _lerp_params_update_
@@ -17,9 +17,9 @@
 ):
     if step == 0 or not s_history:
         # initial step size guess from pytorch
-        scale = 1 / tensors_.abs().global_sum()
-        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-        return tensors_.mul_(min(1.0, scale)) # pyright: ignore[reportArgumentType]
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        return tensors_.mul_(scale_factor)

     m = len(s_history)

@@ -65,9 +65,10 @@ def lsr1_(
         Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy) # pyright:ignore[reportArgumentType]

     if scale_second and step == 1:
-        scale = 1 / tensors_.abs().global_sum()
-        if scale < 1e-5: scale = 1 / tensors_.abs().mean()
-        Hx.mul_(min(1.0, scale)) # pyright:ignore[reportArgumentType]
+        scale_factor = 1 / TensorList(tensors_).abs().global_sum().clip(min=1)
+        scale_factor = scale_factor.clip(min=torch.finfo(tensors_[0].dtype).eps)
+        Hx.mul_(scale_factor)
+
     return Hx


@@ -122,9 +123,9 @@ class LSR1(Module):


     @torch.no_grad
-    def step(self, vars: Vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var: Var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -134,10 +135,10 @@ class LSR1(Module):
         settings = self.settings[params[0]]
         tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)

-        params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta', params=params) # type: ignore
+        params_beta, grads_beta_ = self.get_settings(params, 'params_beta', 'grads_beta') # type: ignore
         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)

-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)

         y_k = None
         if step != 0:
@@ -152,13 +153,13 @@ class LSR1(Module):
             prev_l_grad.copy_(l_update)

         if 'inner' in self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update
-                return vars
+                var.update = update
+                return var

         dir = lsr1_(
             tensors_=update,
@@ -168,6 +169,6 @@ class LSR1(Module):
             scale_second=scale_second,
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var
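For context on what lsr1_() applies in limited-memory form: the SR1 inverse-Hessian update is H ← H + w wᵀ / (wᵀ y) with w = s − H y, skipped when wᵀ y is too small. A hedged dense sketch of a single update (illustrative only; the module never forms a dense H, it replays the stored history against the incoming vector, as in the `Hx.add_(w_k, alpha=w_k.dot(tensors_) / wy)` line above):

```python
import torch

def sr1_update_(H: torch.Tensor, s: torch.Tensor, y: torch.Tensor, r: float = 1e-8) -> torch.Tensor:
    """One dense SR1 inverse-Hessian update with the standard skip rule."""
    w = s - H @ y                          # w_k = s_k - H y_k
    denom = w.dot(y)
    # skip when |w^T y| is tiny relative to ||w|| ||y||
    if denom.abs() <= r * w.norm() * y.norm():
        return H
    H += w.outer(w) / denom
    return H
```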
torchzero/modules/quasi_newton/olbfgs.py
@@ -5,17 +5,17 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module, Transform, Vars, apply
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from .lbfgs import _adaptive_damping, lbfgs


 @torch.no_grad
-def _store_sk_yk_after_step_hook(optimizer, vars: Vars, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
-    assert vars.closure is not None
-    with torch.enable_grad(): vars.closure()
-    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in vars.params]
-    s_k = vars.params - prev_params
+def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
+    assert var.closure is not None
+    with torch.enable_grad(): var.closure()
+    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
+    s_k = var.params - prev_params
     y_k = grad - prev_grad
     ys_k = s_k.dot(y_k)

@@ -95,11 +95,11 @@ class OnlineLBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
+    def step(self, var):
+        assert var.closure is not None

-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -113,7 +113,7 @@ class OnlineLBFGS(Module):

         # sample gradient at previous params with current mini-batch
         if sample_grads == 'before':
-            prev_params = self.get_state('prev_params', params=params, cls=TensorList)
+            prev_params = self.get_state(params, 'prev_params', cls=TensorList)
            if step == 0:
                s_k = None; y_k = None; ys_k = None
            else:
@@ -121,7 +121,7 @@ class OnlineLBFGS(Module):

                current_params = params.clone()
                params.set_(prev_params)
-                with torch.enable_grad(): vars.closure()
+                with torch.enable_grad(): var.closure()
                y_k = update - params.grad
                ys_k = s_k.dot(y_k)
                params.set_(current_params)
@@ -146,7 +146,7 @@ class OnlineLBFGS(Module):
            ys_k = s_k.dot(y_k)

            # this will run after params are updated by Modular after running all future modules
-            vars.post_step_hooks.append(
+            var.post_step_hooks.append(
                partial(
                    _store_sk_yk_after_step_hook,
                    prev_params=params.clone(),
@@ -164,18 +164,18 @@ class OnlineLBFGS(Module):

         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return vars
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var

         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+            z_ema = self.get_state(params, 'z_ema', cls=TensorList)

         # precondition
         dir = lbfgs(
@@ -190,7 +190,7 @@ class OnlineLBFGS(Module):
             step=step
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var
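The renamed _store_sk_yk_after_step_hook keeps Online L-BFGS's defining trick: after the parameters have been updated, the closure is evaluated again so that s_k and y_k are measured on the same mini-batch rather than across two batches. A hedged sketch of that idea outside the module system; `closure` (recomputes the loss and .grad for the current batch) and `do_step` (updates the parameters in place) are hypothetical placeholders:

```python
import torch

def curvature_pair_same_batch(params, closure, do_step):
    """Return flattened (s_k, y_k) with both gradients taken on one mini-batch."""
    with torch.enable_grad():
        closure()                                            # gradient before the step
    prev_params = [p.detach().clone() for p in params]
    prev_grad = [p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
                 for p in params]

    do_step()                                                # parameter update in place

    with torch.enable_grad():
        closure()                                            # same batch, new parameters
    s_k = torch.cat([(p.detach() - pp).ravel() for p, pp in zip(params, prev_params)])
    y_k = torch.cat([((p.grad.detach() if p.grad is not None else torch.zeros_like(p)) - pg).ravel()
                     for p, pg in zip(params, prev_grad)])
    return s_k, y_k
```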