torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/quasi_newton/cg.py
@@ -1,14 +1,15 @@
 from abc import ABC, abstractmethod
+from typing import Literal

 import torch

-from ...core import Chainable, Transform, apply
-from ...utils import TensorList, as_tensorlist
+from ...core import Chainable, TensorwiseTransform, Transform, apply_transform
+from ...utils import TensorList, as_tensorlist, unpack_dicts, unpack_states


 class ConguateGradientBase(Transform, ABC):
     """all CGs are the same except beta calculation"""
-    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None = None, inner: Chainable | None = None):
+    def __init__(self, defaults = None, clip_beta: bool = False, reset_interval: int | None | Literal['auto'] = None, inner: Chainable | None = None):
         if defaults is None: defaults = {}
         defaults['reset_interval'] = reset_interval
         defaults['clip_beta'] = clip_beta
@@ -25,12 +26,12 @@ class ConguateGradientBase(Transform, ABC):
         """returns beta"""

     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         tensors = as_tensorlist(tensors)
         params = as_tensorlist(params)

         step = self.global_state.get('step', 0)
-        prev_dir, prev_grads = self.get_state('prev_dir', 'prev_grad', params=params, cls=TensorList)
+        prev_dir, prev_grads = unpack_states(states, tensors, 'prev_dir', 'prev_grad', cls=TensorList)

         # initialize on first step
         if step == 0:
@@ -42,12 +43,12 @@ class ConguateGradientBase(Transform, ABC):

         # get beta
         beta = self.get_beta(params, tensors, prev_grads, prev_dir)
-        if self.settings[params[0]]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
+        if settings[0]['clip_beta']: beta = max(0, beta) # pyright:ignore[reportArgumentType]
         prev_grads.copy_(tensors)

         # inner step
         if 'inner' in self.children:
-            tensors = as_tensorlist(apply(self.children['inner'], tensors, params, grads, vars))
+            tensors = as_tensorlist(apply_transform(self.children['inner'], tensors, params, grads))

         # calculate new direction with beta
         dir = tensors.add_(prev_dir.mul_(beta))
@@ -55,7 +56,8 @@ class ConguateGradientBase(Transform, ABC):

         # resetting
         self.global_state['step'] = step + 1
-        reset_interval = self.settings[params[0]]['reset_interval']
+        reset_interval = settings[0]['reset_interval']
+        if reset_interval == 'auto': reset_interval = tensors.global_numel() + 1
         if reset_interval is not None and (step+1) % reset_interval == 0:
             self.reset()

@@ -82,7 +84,7 @@ def fletcher_reeves_beta(gg: torch.Tensor, prev_gg: torch.Tensor):

 class FletcherReeves(ConguateGradientBase):
     """Fletcher–Reeves nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = 'auto', clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def initialize(self, p, g):
@@ -104,7 +106,7 @@ def hestenes_stiefel_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):

 class HestenesStiefel(ConguateGradientBase):
     """Hestenes–Stiefel nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -119,7 +121,7 @@ def dai_yuan_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):

 class DaiYuan(ConguateGradientBase):
     """Dai–Yuan nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -134,7 +136,7 @@ def liu_storey_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList, ):

 class LiuStorey(ConguateGradientBase):
     """Liu-Storey nonlinear conjugate gradient method. This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -151,10 +153,10 @@ class ConjugateDescent(Transform):


     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         g = as_tensorlist(tensors)

-        prev_d = self.get_state('prev_dir', params=params, cls=TensorList, init = torch.zeros_like)
+        prev_d = unpack_states(states, tensors, 'prev_dir', cls=TensorList, init=torch.zeros_like)
         if 'denom' not in self.global_state:
             self.global_state['denom'] = torch.tensor(0.).to(g[0])

@@ -164,7 +166,7 @@ class ConjugateDescent(Transform):

         # inner step
         if 'inner' in self.children:
-            g = as_tensorlist(apply(self.children['inner'], g, params, grads, vars))
+            g = as_tensorlist(apply_transform(self.children['inner'], g, params, grads))

         dir = g.add_(prev_d.mul_(beta))
         prev_d.copy_(dir)
@@ -187,7 +189,7 @@ def hager_zhang_beta(g:TensorList, prev_d:TensorList, prev_g:TensorList,):
 class HagerZhang(ConguateGradientBase):
     """Hager-Zhang nonlinear conjugate gradient method,
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
@@ -211,8 +213,56 @@ def hs_dy_beta(g: TensorList, prev_d: TensorList,prev_g: TensorList):
 class HybridHS_DY(ConguateGradientBase):
     """HS-DY hybrid conjugate gradient method.
     This requires step size to be determined via a line search, so put a line search like :code:`StrongWolfe` after this."""
-    def __init__(self, reset_interval: int | None = None, clip_beta=False, inner: Chainable | None = None):
+    def __init__(self, reset_interval: int | None | Literal['auto'] = None, clip_beta=False, inner: Chainable | None = None):
         super().__init__(clip_beta=clip_beta, reset_interval=reset_interval, inner=inner)

     def get_beta(self, p, g, prev_g, prev_d):
         return hs_dy_beta(g, prev_d, prev_g)
+
+
+def projected_gradient_(H:torch.Tensor, y:torch.Tensor, tol: float):
+    Hy = H @ y
+    denom = y.dot(Hy)
+    if denom.abs() < tol: return H
+    H -= (H @ y.outer(y) @ H) / denom
+    return H
+
+class ProjectedGradientMethod(TensorwiseTransform):
+    """Pearson, J. D. (1969). Variable metric methods of minimisation. The Computer Journal, 12(2), 171–178. doi:10.1093/comjnl/12.2.171.
+
+    (This is not the same as projected gradient descent)
+    """
+
+    def __init__(
+        self,
+        tol: float = 1e-10,
+        reset_interval: int | None = None,
+        update_freq: int = 1,
+        scale_first: bool = False,
+        concat_params: bool = True,
+        inner: Chainable | None = None,
+    ):
+        defaults = dict(reset_interval=reset_interval, tol=tol)
+        super().__init__(defaults, uses_grad=False, scale_first=scale_first, concat_params=concat_params, update_freq=update_freq, inner=inner)
+
+    def update_tensor(self, tensor, param, grad, loss, state, settings):
+        step = state.get('step', 0)
+        state['step'] = step + 1
+        reset_interval = settings['reset_interval']
+        if reset_interval is None: reset_interval = tensor.numel() + 1 # as recommended
+
+        if ("H" not in state) or (step % reset_interval == 0):
+            state["H"] = torch.eye(tensor.numel(), device=tensor.device, dtype=tensor.dtype)
+            state['g_prev'] = tensor.clone()
+            return
+
+        H = state['H']
+        g_prev = state['g_prev']
+        state['g_prev'] = tensor.clone()
+        y = (tensor - g_prev).ravel()
+
+        projected_gradient_(H, y, settings['tol'])
+
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
+        H = state['H']
+        return (H @ tensor.view(-1)).view_as(tensor)
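For reference, the `*_beta` helpers named in these hunks compute the conjugate-gradient coefficient; only their signatures appear in the diff, so the formulas below are the standard textbook definitions (with y_{k-1} = g_k - g_{k-1} and d_{k-1} the previous direction), not necessarily the exact library code:

\beta_k^{FR} = \frac{g_k^\top g_k}{g_{k-1}^\top g_{k-1}}, \qquad
\beta_k^{HS} = \frac{g_k^\top y_{k-1}}{d_{k-1}^\top y_{k-1}}, \qquad
\beta_k^{DY} = \frac{g_k^\top g_k}{d_{k-1}^\top y_{k-1}}, \qquad
\beta_k^{LS} = -\frac{g_k^\top y_{k-1}}{d_{k-1}^\top g_{k-1}}

The added `projected_gradient_` helper, read directly off the new lines, applies the rank-one projection

H \leftarrow H - \frac{H\, y\, y^\top H}{y^\top H y}

and skips it when |y^\top H y| < tol. `ProjectedGradientMethod` then preconditions the gradient with H and resets H to the identity every numel + 1 steps by default, mirroring the `'auto'` reset interval added to the CG classes (restart every n + 1 steps, n being the total number of parameters).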
torchzero/modules/quasi_newton/experimental/modular_lbfgs.py
@@ -4,7 +4,7 @@ from typing import Any

 import torch

-from ....core import Chainable, Module, Transform, Vars, apply, maybe_chain
+from ....core import Chainable, Module, Transform, Var, apply_transform, maybe_chain
 from ....utils import NumberList, TensorList, as_tensorlist


@@ -28,7 +28,7 @@ def _adaptive_damping(

 def lbfgs(
     tensors_: TensorList,
-    vars: Vars,
+    var: Var,
     s_history: deque[TensorList],
     y_history: deque[TensorList],
     sy_history: deque[torch.Tensor],
@@ -60,7 +60,7 @@ def lbfgs(
     z = q * (ys_k / (y_k.dot(y_k)))

     if z_tfm is not None:
-        z = TensorList(apply(z_tfm, tensors=z, params=vars.params, grads=vars.grad, vars=vars))
+        z = TensorList(apply_transform(z_tfm, tensors=z, params=var.params, grads=var.grad, var=var))

     # 2nd loop
     for s_i, y_i, ys_i, alpha_i in zip(s_history, y_history, sy_history, reversed(alpha_list)):
@@ -73,28 +73,28 @@
 def _apply_tfms_into_history(
     self: Module,
     params: list[torch.Tensor],
-    vars: Vars,
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_history_tfm' in self.children:
-        params = apply(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+        params = apply_transform(self.children['params_history_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)

     if 'grad_history_tfm' in self.children:
-        update = apply(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=vars.grad, vars=vars)
+        update = apply_transform(self.children['grad_history_tfm'], tensors=as_tensorlist(update).clone(), params=params, grads=var.grad, var=var)

     return params, update

 def _apply_tfms_into_precond(
     self: Module,
     params: list[torch.Tensor],
-    vars: Vars,
+    var: Var,
     update: list[torch.Tensor],
 ):
     if 'params_precond_tfm' in self.children:
-        params = apply(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=vars.grad, vars=vars)
+        params = apply_transform(self.children['params_precond_tfm'], tensors=as_tensorlist(params).clone(), params=params, grads=var.grad, var=var)

     if 'grad_precond_tfm' in self.children:
-        update = apply(self.children['grad_precond_tfm'], tensors=update, params=params, grads=vars.grad, vars=vars)
+        update = apply_transform(self.children['grad_precond_tfm'], tensors=update, params=params, grads=var.grad, var=var)

     return params, update

@@ -165,9 +165,9 @@ class ModularLBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -186,11 +186,11 @@ class ModularLBFGS(Module):
         params_h, update_h = _apply_tfms_into_history(
             self,
             params=params,
-            vars=vars,
+            var=var,
             update=update,
         )

-        prev_params_h, prev_grad_h = self.get_state('prev_params_h', 'prev_grad_h', params=params, cls=TensorList)
+        prev_params_h, prev_grad_h = self.get_state(params, 'prev_params_h', 'prev_grad_h', cls=TensorList)

         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -217,16 +217,16 @@ class ModularLBFGS(Module):
         # step with inner module before applying preconditioner
         if 'update_precond_tfm' in self.children:
             update_precond_tfm = self.children['update_precond_tfm']
-            inner_vars = update_precond_tfm.step(vars.clone(clone_update=True))
-            vars.update_attrs_from_clone_(inner_vars)
-            tensors = inner_vars.update
+            inner_var = update_precond_tfm.step(var.clone(clone_update=True))
+            var.update_attrs_from_clone_(inner_var)
+            tensors = inner_var.update
             assert tensors is not None
         else:
             tensors = update.clone()

         # transforms into preconditioner
-        params_p, update_p = _apply_tfms_into_precond(self, params=params, vars=vars, update=update)
-        prev_params_p, prev_grad_p = self.get_state('prev_params_p', 'prev_grad_p', params=params, cls=TensorList)
+        params_p, update_p = _apply_tfms_into_precond(self, params=params, var=var, update=update)
+        prev_params_p, prev_grad_p = self.get_state(params, 'prev_params_p', 'prev_grad_p', cls=TensorList)

         if step == 0:
             s_k_p = None; y_k_p = None; ys_k_p = None
@@ -245,13 +245,13 @@ class ModularLBFGS(Module):
         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k_p is not None and y_k_p.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return vars
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var

         # precondition
         dir = lbfgs(
             tensors_=as_tensorlist(tensors),
-            vars=vars,
+            var=var,
             s_history=s_history,
             y_history=y_history,
             sy_history=sy_history,
@@ -260,7 +260,7 @@ class ModularLBFGS(Module):
             z_tfm=self.children.get('z_tfm', None),
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var

torchzero/modules/quasi_newton/lbfgs.py
@@ -2,7 +2,7 @@ from collections import deque
 from operator import itemgetter
 import torch

-from ...core import Transform, Chainable, Module, Vars, apply
+from ...core import Transform, Chainable, Module, Var, apply_transform
 from ...utils import TensorList, as_tensorlist, NumberList


@@ -154,9 +154,9 @@ class LBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -167,10 +167,10 @@ class LBFGS(Module):

         tol, damping, init_damping, eigval_bounds, update_freq, z_beta, tol_reset = itemgetter(
             'tol', 'damping', 'init_damping', 'eigval_bounds', 'update_freq', 'z_beta', 'tol_reset')(self.settings[params[0]])
-        params_beta, grads_beta = self.get_settings('params_beta', 'grads_beta', params=params)
+        params_beta, grads_beta = self.get_settings(params, 'params_beta', 'grads_beta')

         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta)
-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)

         # 1st step - there are no previous params and grads, `lbfgs` will do normalized SGD step
         if step == 0:
@@ -196,19 +196,19 @@ class LBFGS(Module):

         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
                 if tol_reset: self.reset()
-                return vars
+                return var

         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+            z_ema = self.get_state(var.params, 'z_ema', cls=TensorList)

         # precondition
         dir = lbfgs(
@@ -223,7 +223,7 @@ class LBFGS(Module):
             step=step
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var

torchzero/modules/quasi_newton/lsr1.py
@@ -3,7 +3,7 @@ from operator import itemgetter

 import torch

-from ...core import Chainable, Module, Transform, Vars, apply
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist

 from .lbfgs import _lerp_params_update_
@@ -123,9 +123,9 @@ class LSR1(Module):


     @torch.no_grad
-    def step(self, vars: Vars):
-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+    def step(self, var: Var):
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -135,10 +135,10 @@ class LSR1(Module):
         settings = self.settings[params[0]]
         tol, update_freq, scale_second = itemgetter('tol', 'update_freq', 'scale_second')(settings)

-        params_beta, grads_beta_ = self.get_settings('params_beta', 'grads_beta', params=params) # type: ignore
+        params_beta, grads_beta_ = self.get_settings(params, 'params_beta', 'grads_beta') # type: ignore
         l_params, l_update = _lerp_params_update_(self, params, update, params_beta, grads_beta_)

-        prev_l_params, prev_l_grad = self.get_state('prev_l_params', 'prev_l_grad', params=params, cls=TensorList)
+        prev_l_params, prev_l_grad = self.get_state(params, 'prev_l_params', 'prev_l_grad', cls=TensorList)

         y_k = None
         if step != 0:
@@ -153,13 +153,13 @@ class LSR1(Module):
         prev_l_grad.copy_(l_update)

         if 'inner' in self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update
-                return vars
+                var.update = update
+                return var

         dir = lsr1_(
             tensors_=update,
@@ -169,6 +169,6 @@ class LSR1(Module):
             scale_second=scale_second,
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var
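The `lbfgs` helper called by the L-BFGS modules above (and by OnlineLBFGS below) is, at its core, the standard two-loop recursion over (s, y) pairs. The sketch below is only an orientation aid on flat tensors: the function name is made up, and the library version shown in these hunks additionally works on TensorLists and threads damping, eigenvalue bounds and the z_tfm/z_ema hooks through the recursion.

import torch

def two_loop_lbfgs(g: torch.Tensor,
                   s_history: list[torch.Tensor],
                   y_history: list[torch.Tensor]) -> torch.Tensor:
    """Approximate H_k @ g, where H_k is the running inverse-Hessian estimate."""
    q = g.clone()
    alphas = []
    # first loop: newest (s, y) pair to oldest
    for s, y in zip(reversed(s_history), reversed(y_history)):
        rho = 1.0 / y.dot(s)
        alpha = rho * s.dot(q)
        q -= alpha * y
        alphas.append(alpha)
    # scale by gamma_k = s_k^T y_k / y_k^T y_k as the initial H_0 guess
    if s_history:
        s_k, y_k = s_history[-1], y_history[-1]
        q *= s_k.dot(y_k) / y_k.dot(y_k)
    # second loop: oldest pair to newest
    for (s, y), alpha in zip(zip(s_history, y_history), reversed(alphas)):
        rho = 1.0 / y.dot(s)
        beta = rho * y.dot(q)
        q += (alpha - beta) * s
    return q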
torchzero/modules/quasi_newton/olbfgs.py
@@ -5,17 +5,17 @@ from typing import Literal

 import torch

-from ...core import Chainable, Module, Transform, Vars, apply
+from ...core import Chainable, Module, Transform, Var, apply_transform
 from ...utils import NumberList, TensorList, as_tensorlist
 from .lbfgs import _adaptive_damping, lbfgs


 @torch.no_grad
-def _store_sk_yk_after_step_hook(optimizer, vars: Vars, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
-    assert vars.closure is not None
-    with torch.enable_grad(): vars.closure()
-    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in vars.params]
-    s_k = vars.params - prev_params
+def _store_sk_yk_after_step_hook(optimizer, var: Var, prev_params: TensorList, prev_grad: TensorList, damping, init_damping, eigval_bounds, s_history: deque[TensorList], y_history: deque[TensorList], sy_history: deque[torch.Tensor]):
+    assert var.closure is not None
+    with torch.enable_grad(): var.closure()
+    grad = [p.grad if p.grad is not None else torch.zeros_like(p) for p in var.params]
+    s_k = var.params - prev_params
     y_k = grad - prev_grad
     ys_k = s_k.dot(y_k)

@@ -95,11 +95,11 @@ class OnlineLBFGS(Module):
         self.global_state['sy_history'].clear()

     @torch.no_grad
-    def step(self, vars):
-        assert vars.closure is not None
+    def step(self, var):
+        assert var.closure is not None

-        params = as_tensorlist(vars.params)
-        update = as_tensorlist(vars.get_update())
+        params = as_tensorlist(var.params)
+        update = as_tensorlist(var.get_update())
         step = self.global_state.get('step', 0)
         self.global_state['step'] = step + 1

@@ -113,7 +113,7 @@ class OnlineLBFGS(Module):

         # sample gradient at previous params with current mini-batch
         if sample_grads == 'before':
-            prev_params = self.get_state('prev_params', params=params, cls=TensorList)
+            prev_params = self.get_state(params, 'prev_params', cls=TensorList)
            if step == 0:
                 s_k = None; y_k = None; ys_k = None
             else:
@@ -121,7 +121,7 @@ class OnlineLBFGS(Module):

                 current_params = params.clone()
                 params.set_(prev_params)
-                with torch.enable_grad(): vars.closure()
+                with torch.enable_grad(): var.closure()
                 y_k = update - params.grad
                 ys_k = s_k.dot(y_k)
                 params.set_(current_params)
@@ -146,7 +146,7 @@ class OnlineLBFGS(Module):
            ys_k = s_k.dot(y_k)

            # this will run after params are updated by Modular after running all future modules
-            vars.post_step_hooks.append(
+            var.post_step_hooks.append(
                partial(
                    _store_sk_yk_after_step_hook,
                    prev_params=params.clone(),
@@ -164,18 +164,18 @@ class OnlineLBFGS(Module):

         # step with inner module before applying preconditioner
         if self.children:
-            update = TensorList(apply(self.children['inner'], tensors=update, params=params, grads=vars.grad, vars=vars))
+            update = TensorList(apply_transform(self.children['inner'], tensors=update, params=params, grads=var.grad, var=var))

         # tolerance on gradient difference to avoid exploding after converging
         if tol is not None:
             if y_k is not None and y_k.abs().global_max() <= tol:
-                vars.update = update # may have been updated by inner module, probably makes sense to use it here?
-                return vars
+                var.update = update # may have been updated by inner module, probably makes sense to use it here?
+                return var

         # lerp initial H^-1 @ q guess
         z_ema = None
         if z_beta is not None:
-            z_ema = self.get_state('z_ema', params=vars.params, cls=TensorList)
+            z_ema = self.get_state(params, 'z_ema', cls=TensorList)

         # precondition
         dir = lbfgs(
@@ -190,7 +190,7 @@ class OnlineLBFGS(Module):
             step=step
         )

-        vars.update = dir
+        var.update = dir

-        return vars
+        return var

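Taken together, these quasi-Newton hunks show the API migration visible throughout this diff: `Vars` is renamed to `Var`, the free function `apply` to `apply_transform`, `Transform.transform(tensors, params, grads, vars)` becomes `Transform.apply(tensors, params, grads, loss, states, settings)`, and per-parameter state and settings move to `unpack_states(states, ...)`, `settings[0][...]`, `self.get_state(params, ...)` and `self.get_settings(params, ...)`. A hypothetical minimal transform written against the new signature might look like the sketch below; the base-class constructor call and the return convention are assumptions inferred from these hunks, not documented API.

import torch
from torchzero.core import Transform
from torchzero.utils import TensorList, as_tensorlist, unpack_states

class HeavyBall(Transform):
    """Illustrative momentum transform; not part of the package."""
    def __init__(self, momentum: float = 0.9):
        super().__init__(dict(momentum=momentum))  # assumed: defaults dict as the first argument

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        tensors = as_tensorlist(tensors)
        # per-parameter state comes from `states`, per-parameter settings from `settings`
        buf = unpack_states(states, tensors, 'momentum_buffer', cls=TensorList, init=torch.zeros_like)
        momentum = settings[0]['momentum']
        buf.mul_(momentum).add_(tensors)
        return buf.clone()  # assumed: the transformed update is returned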