torchzero 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +54 -21
  2. tests/test_tensorlist.py +2 -2
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +19 -129
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +12 -12
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +67 -17
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +12 -12
  78. torchzero/modules/quasi_newton/lsr1.py +11 -11
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +254 -47
  81. torchzero/modules/second_order/newton.py +32 -20
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +21 -21
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.9.dist-info/RECORD +0 -131
  107. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.9.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/optimizers/adam.py
@@ -3,8 +3,8 @@ from functools import partial
 
 import torch
 
-from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, apply_transform, Chainable
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import (
     debias, debiased_step_size,
     ema_,
@@ -27,24 +27,25 @@ def adam_(
     pow: float = 2,
     debiased: bool = True,
     max_exp_avg_sq_: TensorList | None = None,
-    params_: TensorList | None = None,
+
+    # inner args
+    inner: Module | None = None,
+    params: list[torch.Tensor] | None = None,
+    grads: list[torch.Tensor] | None = None,
 ):
     """Returns new tensors or updates params in-place."""
-    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
-
     sqrt_exp_avg_sq = sqrt_ema_sq_(tensors, exp_avg_sq_=exp_avg_sq_, beta=beta2, max_exp_avg_sq_=max_exp_avg_sq_,
         debiased=False,step=step,pow=pow)
 
-    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    if inner is not None:
+        assert params is not None
+        tensors = TensorList(apply_transform(inner, tensors, params=params, grads=grads))
 
-    # params is None, return update
-    if params_ is None: return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
-
-    # update params in-place
-    params_.addcdiv_(exp_avg_, sqrt_exp_avg_sq.add_(eps), -alpha)
-    return None
+    exp_avg_ = ema_(tensors, exp_avg_=exp_avg_, beta=beta1, dampening=0,lerp=True)
+    if debiased: alpha = debiased_step_size(step, beta1=beta1, beta2=beta2, pow=pow, alpha=alpha)
+    return (exp_avg_ / sqrt_exp_avg_sq.add_(eps)).lazy_mul(alpha)
 
-class Adam(Module):
+class Adam(Transform):
     """Adam. Divides gradient EMA by EMA of gradient squares with debiased step size. This implementation is slightly different from
     pytorch in that debiasing is applied after adding epsilon.
 
@@ -66,36 +67,29 @@ class Adam(Module):
         alpha: float = 1.,
         pow: float = 2,
         debiased: bool = True,
+        inner: Chainable | None = None
     ):
         defaults=dict(beta1=beta1,beta2=beta2,eps=eps,alpha=alpha,amsgrad=amsgrad,pow=pow,debiased=debiased)
-        super().__init__(defaults)
-        self.getter = itemgetter('amsgrad','pow','debiased')
+        super().__init__(defaults, uses_grad=False)
+
+        if inner is not None: self.set_child('inner', inner)
 
     @torch.no_grad
-    def step(self, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1,beta2,eps,alpha=self.get_settings('beta1','beta2','eps','alpha', params=vars.params, cls=NumberList)
-        amsgrad,pow,debiased = self.getter(self.settings[vars.params[0]])
+        beta1,beta2,eps,alpha=unpack_dicts(settings, 'beta1','beta2','eps','alpha', cls=NumberList)
+        amsgrad,pow,debiased = itemgetter('amsgrad','pow','debiased')(settings[0])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg','exp_avg_sq','max_exp_avg_sq', params=vars.params, cls=TensorList)
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg','exp_avg_sq', params=vars.params, cls=TensorList)
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
-        # if this is last module, update parameters in-place with slightly more efficient addcdiv_
-        if vars.is_last:
-            if vars.last_module_lrs is not None: alpha = alpha * vars.last_module_lrs
-            passed_params = TensorList(vars.params)
-            vars.stop = True
-            vars.skip_update = True
-
-        else:
-            passed_params = None
 
-        vars.update = adam_(
-            tensors=TensorList(vars.get_update()),
+        return adam_(
+            tensors=TensorList(tensors),
             exp_avg_=exp_avg,
             exp_avg_sq_=exp_avg_sq,
             alpha=alpha,
@@ -106,7 +100,10 @@ class Adam(Module):
             pow=pow,
             debiased=debiased,
             max_exp_avg_sq_=max_exp_avg_sq,
-            params_=passed_params,
-        )
 
-        return vars
+            # inner args
+            inner=self.children.get("inner", None),
+            params=params,
+            grads=grads,
+
+        )
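For readers following the API change: `Adam` moves from `Module.step(vars)` to the `Transform` interface, where `apply` receives parallel per-parameter sequences (`tensors`, `params`, `grads`, `states`, `settings`) plus the shared `loss`, and an optional `inner` child transform is applied between the second-moment and first-moment updates via `apply_transform`. Below is a minimal sketch of a custom transform written against that interface as it appears in the hunks above. The class itself (`EMASmooth`) is hypothetical, and the assumption that `unpack_states` initializes missing state with zeros (the Rprop hunk passes an explicit `init=` for other choices) is mine, inferred from this diff rather than documented.

import torch
from torchzero.core import Transform
from torchzero.utils import TensorList, unpack_states

class EMASmooth(Transform):
    """Hypothetical example: replaces the update with its exponential moving average."""
    def __init__(self, beta: float = 0.9):
        super().__init__(dict(beta=beta), uses_grad=False)

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        beta = settings[0]['beta']  # scalar setting, read the same way the hunks above read flags
        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)  # per-parameter state
        out = []
        for t, ema in zip(tensors, exp_avg):
            ema.lerp_(t, 1 - beta)   # in-place EMA update, persisted via `states`
            out.append(ema.clone())  # return fresh tensors; keep internal state private
        return out

Whether a transform needs any hook besides `apply` is not visible in this diff, so treat the sketch as a starting point rather than a template.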
torchzero/modules/optimizers/lion.py
@@ -1,7 +1,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 
 
 def lion_(tensors: TensorList, exp_avg_: TensorList, beta1, beta2,):
@@ -28,8 +28,8 @@ class Lion(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        beta1, beta2 = self.get_settings('beta1', 'beta2', params = params, cls=NumberList)
-        exp_avg = self.get_state('ema', params=params, cls=TensorList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        beta1, beta2 = unpack_dicts(settings, 'beta1', 'beta2', cls=NumberList)
+        exp_avg = unpack_states(states, tensors, 'ema', cls=TensorList)
         return lion_(TensorList(tensors),exp_avg,beta1,beta2)
 
torchzero/modules/optimizers/muon.py
@@ -164,7 +164,7 @@ class Orthogonalize(TensorwiseTransform):
         method (str, optional):
             Newton-Schulz is very fast, SVD is extremely slow but can be slighly more precise.
         target (str, optional):
-            what to set on vars.
+            what to set on var.
     """
     def __init__(self, ns_steps=5, adjust_lr=False, dual_norm_correction=False,
                  method: Literal['newton-schulz', 'svd'] = 'newton-schulz', target:Target='update'):
@@ -172,9 +172,9 @@ class Orthogonalize(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def transform(self, tensor, param, grad, vars):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         orthogonalize, ns_steps, dual_norm_correction, adjust_lr, method = itemgetter(
-            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(self.settings[param])
+            'orthogonalize', 'ns_steps', 'dual_norm_correction', 'adjust_lr', 'method')(settings)
 
         if not orthogonalize: return tensor
 
@@ -199,7 +199,7 @@ class DualNormCorrection(TensorwiseTransform):
     def __init__(self, target: Target='update'):
         super().__init__({}, uses_grad=True, target=target)
 
-    def transform(self, tensor, param, grad, vars):
+    def apply_tensor(self, tensor, param, grad, loss, state, settings):
         assert grad is not None
         if (tensor.ndim >= 2) and (tensor.size(0) > 1) and (tensor.size(1) > 1):
             return _dual_norm_correction(tensor, grad, batch_first=False)
@@ -213,8 +213,8 @@ class MuonAdjustLR(Transform):
         defaults = dict(alpha=alpha)
         super().__init__(defaults=defaults, uses_grad=False, target=target)
 
-    def transform(self, tensors, params, grads, vars):
-        alphas = self.get_settings('alpha', params=params)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        alphas = [s['alpha'] for s in settings]
         tensors_alphas = [(t, adjust_lr_for_muon(a, t.shape)) for t, a in zip(tensors, alphas) if _is_at_least_2d(t)]
         tensors = [i[0] for i in tensors_alphas]
         a = [i[1] for i in alphas]
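`Orthogonalize` documents two methods, Newton-Schulz and SVD, for orthogonalizing a 2D (or higher) update. For intuition, here is a reference sketch of the SVD route only, i.e. replacing a matrix with its nearest semi-orthogonal matrix; it is an illustration, not torchzero's code, and it ignores the Newton-Schulz path, `dual_norm_correction`, and the Muon learning-rate adjustment handled by the classes above.

import torch

def orthogonalize_svd_reference(update: torch.Tensor) -> torch.Tensor:
    """Nearest semi-orthogonal matrix to `update`: drop the singular values, keep U @ Vh."""
    U, S, Vh = torch.linalg.svd(update, full_matrices=False)
    return U @ Vh

The Newton-Schulz iteration used by Muon approximates the same map with a handful of matrix multiplications instead of a full SVD, which is why the docstring calls it much faster.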
torchzero/modules/optimizers/orthograd.py
@@ -30,16 +30,15 @@ class OrthoGrad(Transform):
     Args:
         eps (float, optional): epsilon added to the denominator for numerical stability (default: 1e-30)
         renormalize (bool, optional): whether to graft projected gradient to original gradient norm. Defaults to True.
-        target (Target, optional): what to set on vars. Defaults to 'update'.
+        target (Target, optional): what to set on var. Defaults to 'update'.
     """
     def __init__(self, eps: float = 1e-8, renormalize=True, target: Target = 'update'):
         defaults = dict(eps=eps, renormalize=renormalize)
         super().__init__(defaults, uses_grad=False, target=target)
 
-    def transform(self, tensors, params, grads, vars):
-        settings = self.settings[params[0]]
-        eps = settings['eps']
-        renormalize = settings['renormalize']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        eps = settings[0]['eps']
+        renormalize = settings[0]['renormalize']
 
         params = as_tensorlist(params)
         target = as_tensorlist(tensors)
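The hunk only changes how `OrthoGrad` reads its settings, so for context: OrthoGrad removes the component of each gradient that is parallel to its weight tensor and, with `renormalize=True`, grafts the result back onto the original gradient norm. A per-tensor reference sketch of that idea follows; the exact epsilon placement and the TensorList-based batching inside torchzero may differ.

import torch

def orthograd_reference(w: torch.Tensor, g: torch.Tensor,
                        eps: float = 1e-8, renormalize: bool = True) -> torch.Tensor:
    """Project g onto the subspace orthogonal to w, optionally keeping g's original norm."""
    w_flat, g_flat = w.reshape(-1), g.reshape(-1)
    coeff = (w_flat @ g_flat) / (w_flat @ w_flat + eps)  # how much of g points along w
    g_orth = g - coeff * w
    if renormalize:
        g_orth = g_orth * (g.norm() / g_orth.norm().clamp(min=eps))
    return g_orth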
torchzero/modules/optimizers/rmsprop.py
@@ -3,8 +3,8 @@ from typing import Literal
 
 import torch
 
-from ...core import Module, Target, Transform, Chainable, Vars, apply
-from ...utils import NumberList, TensorList
+from ...core import Module, Target, Transform, Chainable, Var, apply_transform
+from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
 from ..functional import sqrt_centered_ema_sq_, sqrt_ema_sq_
 
 
@@ -23,7 +23,6 @@ def rmsprop_(
     inner: Module | None = None,
     params: list[torch.Tensor] | None = None,
     grads: list[torch.Tensor] | None = None,
-    vars: Vars | None = None,
 ):
     """returns `tensors_`"""
     if exp_avg_ is not None:
@@ -36,7 +35,7 @@
 
     if inner is not None:
         assert params is not None
-        tensors_ = TensorList(apply(inner, tensors_, params=params, grads=grads, vars=vars))
+        tensors_ = TensorList(apply_transform(inner, tensors_, params=params, grads=grads))
 
     return tensors_.div_(sqrt_exp_avg_sq.add_(eps))
 
@@ -66,21 +65,20 @@ class RMSprop(Transform):
     ):
         defaults = dict(smoothing=smoothing,eps=eps,centered=centered,debiased=debiased,amsgrad=amsgrad,pow=pow,init=init)
         super().__init__(defaults=defaults, uses_grad=False)
-        self.current_step = 0
+
         if inner is not None:
             self.set_child('inner', inner)
 
-    def transform(self, tensors, params, grads, vars):
-        self.current_step += 1
-
-        smoothing,eps = self.get_settings('smoothing', 'eps', params=params, cls=NumberList)
-        centered,debiased,amsgrad,pow,init = itemgetter('centered','debiased','amsgrad','pow','init')(self.settings[params[0]])
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+        smoothing, eps = unpack_dicts(settings, 'smoothing', 'eps', cls=NumberList)
+        centered, debiased, amsgrad, pow, init = itemgetter('centered','debiased','amsgrad','pow','init')(settings[0])
 
-        exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
-        exp_avg = self.get_state('exp_avg', params=params, cls=TensorList) if centered else None
-        max_exp_avg_sq = self.get_state('max_exp_avg_sq', params=params, cls=TensorList) if amsgrad else None
+        exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
+        exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList) if centered else None
+        max_exp_avg_sq = unpack_states(states, tensors, 'max_exp_avg_sq', cls=TensorList) if amsgrad else None
 
-        if init == 'update' and self.current_step == 1:
+        if init == 'update' and step == 1:
             exp_avg_sq.set_([t**2 for t in tensors])
             if exp_avg is not None: exp_avg.set_([t.clone() for t in tensors])
 
@@ -90,7 +88,7 @@ class RMSprop(Transform):
             smoothing=smoothing,
             eps=eps,
             debiased=debiased,
-            step=self.current_step,
+            step=step,
             exp_avg_=exp_avg,
             max_exp_avg_sq_=max_exp_avg_sq,
             pow=pow,
@@ -99,5 +97,4 @@ class RMSprop(Transform):
             inner=self.children.get("inner", None),
             params=params,
            grads=grads,
-            vars=vars,
         )
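`RMSprop` gains nothing functionally new here: the hunks move the step counter into `global_state`, switch to the `apply` signature, and stop threading `vars` through `rmsprop_`. As background for the `centered` and `eps` settings that keep appearing above, here is a single-tensor reference sketch of the centered second-moment denominator that `sqrt_centered_ema_sq_` appears to compute in-place over TensorLists; it is illustrative only.

import torch

def centered_rmsprop_denominator(g: torch.Tensor, exp_avg: torch.Tensor,
                                 exp_avg_sq: torch.Tensor,
                                 smoothing: float, eps: float) -> torch.Tensor:
    """Centered RMSprop denominator: sqrt of a running variance estimate, plus eps."""
    exp_avg.lerp_(g, 1 - smoothing)                                  # running mean of g
    exp_avg_sq.mul_(smoothing).addcmul_(g, g, value=1 - smoothing)   # running mean of g**2
    return (exp_avg_sq - exp_avg**2).clamp(min=0).sqrt().add_(eps)   # E[g^2] - E[g]^2, eps after sqrt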
torchzero/modules/optimizers/rprop.py
@@ -2,7 +2,7 @@
 import torch
 
 from ...core import Module, Target, Transform
-from ...utils import NumberList, TensorList, as_tensorlist
+from ...utils import NumberList, TensorList, as_tensorlist, unpack_dicts, unpack_states
 
 
 def _bool_ones_like(x):
@@ -161,20 +161,22 @@ class Rprop(Transform):
         alpha: float = 1,
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, backtrack=backtrack)
-        self.current_step = 0
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        nplus, nminus, lb, ub, alpha = self.get_settings('nplus', 'nminus', 'lb', 'ub', 'alpha', params=params, cls=NumberList)
-        prev, allowed, magnitudes = self.get_state(
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        nplus, nminus, lb, ub, alpha = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', 'alpha', cls=NumberList)
+        prev, allowed, magnitudes = unpack_states(
+            states, tensors,
             'prev','allowed','magnitudes',
-            params=params,
            init=[torch.zeros_like, _bool_ones_like, torch.zeros_like],
            cls = TensorList,
        )
 
-        target = rprop_(
+        tensors = rprop_(
            tensors_ = as_tensorlist(tensors),
            prev_ = prev,
            allowed_ = allowed,
@@ -184,12 +186,11 @@ class Rprop(Transform):
             lb = lb,
             ub = ub,
             alpha = alpha,
-            backtrack=self.settings[params[0]]['backtrack'],
-            step=self.current_step,
+            backtrack=settings[0]['backtrack'],
+            step=step,
         )
 
-        self.current_step += 1
-        return target
+        return tensors
 
 
 class ScaleLRBySignChange(Transform):
@@ -220,23 +221,25 @@ class ScaleLRBySignChange(Transform):
     ):
         defaults = dict(nplus=nplus, nminus=nminus, alpha=alpha, lb=lb, ub=ub, use_grad=use_grad)
         super().__init__(defaults, uses_grad=use_grad, target=target)
-        self.current_step = 0
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        target = as_tensorlist(tensors)
-        use_grad = self.settings[params[0]]['use_grad']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
         if use_grad: cur = as_tensorlist(grads)
-        else: cur = target
+        else: cur = tensors
 
-        nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
-        prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if self.current_step == 0:
-            lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
+        if step == 0:
+            lrs.set_(tensors.full_like([s['alpha'] for s in settings]))
 
-        target = scale_by_sign_change_(
-            tensors_ = target,
+        tensors = scale_by_sign_change_(
+            tensors_ = tensors,
             cur = cur,
             prev_ = prev,
             lrs_ = lrs,
@@ -244,10 +247,9 @@ class ScaleLRBySignChange(Transform):
             nminus = nminus,
             lb = lb,
             ub = ub,
-            step = self.current_step,
+            step = step,
         )
-        self.current_step += 1
-        return target
+        return tensors
 
 class BacktrackOnSignChange(Transform):
     """Negates or undoes update for parameters where where gradient or update sign changes.
@@ -268,28 +270,28 @@ class BacktrackOnSignChange(Transform):
     def __init__(self, use_grad = False, backtrack = True, target: Target = 'update'):
         defaults = dict(use_grad=use_grad, backtrack=backtrack, target=target)
         super().__init__(defaults, uses_grad=use_grad)
-        self.current_step = 0
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        target = as_tensorlist(tensors)
-        settings = self.settings[params[0]]
-        use_grad = settings['use_grad']
-        backtrack = settings['backtrack']
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
+        tensors = as_tensorlist(tensors)
+        use_grad = settings[0]['use_grad']
+        backtrack = settings[0]['backtrack']
 
         if use_grad: cur = as_tensorlist(grads)
-        else: cur = target
+        else: cur = tensors
 
-        target = backtrack_on_sign_change_(
-            tensors_ = target,
+        tensors = backtrack_on_sign_change_(
+            tensors_ = tensors,
             cur = cur,
-            prev_ = self.get_state('prev', params=params, cls=TensorList),
+            prev_ = unpack_states(states, tensors, 'prev', cls=TensorList),
             backtrack = backtrack,
-            step = self.current_step,
+            step = step,
         )
 
-        self.current_step += 1
-        return target
+        return tensors
 
 class SignConsistencyMask(Transform):
     """0 if sign changed 1 otherwise"""
@@ -297,10 +299,10 @@ class SignConsistencyMask(Transform):
         super().__init__({}, uses_grad=False, target = target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        prev = self.get_state('prev', params=params, cls=TensorList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        prev = unpack_states(states, tensors, 'prev', cls=TensorList)
         mask = prev.mul_(tensors).gt_(0)
-        prev.set_(tensors)
+        prev.copy_(tensors)
         return mask
 
 
@@ -317,16 +319,18 @@ class SignConsistencyLRs(Transform):
     ):
         defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
         super().__init__(defaults, uses_grad=False, target = target)
-        self.current_step = 0
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state.get('step', 0)
+        self.global_state['step'] = step + 1
+
         target = as_tensorlist(tensors)
-        nplus, nminus, lb, ub = self.get_settings('nplus', 'nminus', 'lb', 'ub', params=params, cls=NumberList)
-        prev, lrs = self.get_state('prev', 'lrs', params=params, cls=TensorList)
+        nplus, nminus, lb, ub = unpack_dicts(settings, 'nplus', 'nminus', 'lb', 'ub', cls=NumberList)
+        prev, lrs = unpack_states(states, tensors, 'prev', 'lrs', cls=TensorList)
 
-        if self.current_step == 0:
-            lrs.set_(target.full_like(self.get_settings('alpha', params=params)))
+        if step == 0:
+            lrs.set_(target.full_like([s['alpha'] for s in settings]))
 
         target = sign_consistency_lrs_(
             tensors = target,
@@ -336,7 +340,6 @@ class SignConsistencyLRs(Transform):
             nminus = nminus,
             lb = lb,
             ub = ub,
-            step = self.current_step,
+            step = step,
         )
-        self.current_step += 1
         return target.clone()
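All of the classes above share one mechanism: compare the sign of the current update (or gradient, with `use_grad=True`) against the previous step and react to agreements and flips. As a compact reference for that rule, which `scale_by_sign_change_` and `sign_consistency_lrs_` implement over TensorLists with per-parameter `nplus`/`nminus`/`lb`/`ub`, here is a single-tensor sketch; first-step handling, backtracking, and the zero-product case are simplified.

import torch

def sign_change_lrs_reference(cur: torch.Tensor, prev: torch.Tensor, lrs: torch.Tensor,
                              nplus: float = 1.2, nminus: float = 0.5,
                              lb: float = 1e-6, ub: float = 50.0) -> torch.Tensor:
    """Grow per-weight step sizes where the sign is stable, shrink them where it flips."""
    prod = cur * prev
    lrs = torch.where(prod > 0, lrs * nplus, lrs)   # same sign as last step: accelerate
    lrs = torch.where(prod < 0, lrs * nminus, lrs)  # sign flipped: slow down
    return lrs.clamp(lb, ub)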
torchzero/modules/optimizers/shampoo.py
@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 
-from ...core import Chainable, Transform, apply
+from ...core import Chainable, Transform, apply_transform
 from ...utils.linalg import matrix_power_eigh
 from ...utils import set_storage_
 
@@ -106,7 +106,6 @@ class Shampoo(Transform):
         self,
         decay: float | None = None,
         beta: float | None = None,
-        reg: float = 1e-6,
         update_freq: int = 10,
         exp_override: int | None = None,
         merge_small: bool = True,
@@ -115,25 +114,24 @@ class Shampoo(Transform):
         adagrad_eps: float = 1e-8,
         inner: Chainable | None = None,
     ):
-        defaults = dict(decay=decay, beta=beta, reg=reg, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
+        defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
         super().__init__(defaults, uses_grad=False)
 
         if inner is not None:
             self.set_child('inner', inner)
 
-    def transform(self, tensors, params, grads, vars):
-        merged_target = [] # target with merged dims
+    def apply(self, tensors, params, grads, loss, states, settings):
+        merged_tensors = [] # target with merged dims
 
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
-            beta, reg, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
-                'beta', 'reg', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(settings)
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
+                'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
-            merged_target.append(t)
+
+            merged_tensors.append(t)
 
             # initialize accumulators and preconditioners for each dim on 1st step
             if 'accumulators' not in state:
@@ -167,22 +165,18 @@
 
         # inner step
         if 'inner' in self.children:
-            tensors = apply(self.children['inner'], tensors, params=params, grads=grads, vars=vars)
+            tensors = apply_transform(self.children['inner'], tensors, params=params, grads=grads)
 
         # have to merge small dims again
-        merged_target = [] # target with merged dims
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
-            if settings['merge_small']:
-                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, settings['max_dim'])
-            merged_target.append(t)
+        merged_tensors = [] # target with merged dims
+        for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
+            if setting['merge_small']:
+                t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, setting['max_dim'])
+            merged_tensors.append(t)
 
         # precondition
-        for i, (p, t) in enumerate(zip(params, merged_target)):
-            state = self.state[p]
-            settings = self.settings[p]
-            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(settings)
+        for i,(t,state, setting) in enumerate(zip(merged_tensors, states, settings)):
+            decay, merge_small, adagrad_eps= itemgetter('decay', 'merge_small', 'adagrad_eps')(setting)
 
             if 'diagonal_accumulator' in state:
                 tensors[i] = apply_diagonal_(t, state['diagonal_accumulator'], decay=decay, eps=adagrad_eps)
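The functional change here is the removal of the unused `reg` option; the rest is the settings/states refactor. As a reminder of what the per-dimension accumulators and `matrix_power_eigh` are for, here is a reference sketch of the 2D Shampoo preconditioning step, under my own simplifying assumptions (the real module also handles merged small dims, 1D tensors, `exp_override`, `update_freq`, and decay):

import torch

def shampoo_precondition_2d(G: torch.Tensor, L: torch.Tensor, R: torch.Tensor,
                            eps: float = 1e-8) -> torch.Tensor:
    """One 2D Shampoo step: accumulate per-dimension covariances, apply inverse 4th roots."""
    L += G @ G.T   # left covariance accumulator
    R += G.T @ G   # right covariance accumulator

    def inv_root(M: torch.Tensor, p: int) -> torch.Tensor:
        # M^(-1/p) via eigendecomposition (cf. matrix_power_eigh in the imports above)
        vals, vecs = torch.linalg.eigh(M)
        return vecs @ torch.diag(vals.clamp(min=eps).pow(-1.0 / p)) @ vecs.T

    return inv_root(L, 4) @ G @ inv_root(R, 4)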
torchzero/modules/optimizers/soap.py
@@ -2,7 +2,7 @@ from operator import itemgetter
 
 import torch
 
-from ...core import Chainable, Transform, apply
+from ...core import Chainable, Transform, apply_transform
 from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
 
 @torch.no_grad
@@ -152,9 +152,8 @@ class SOAP(Transform):
             epsilon for dividing first momentum by second. Defaults to 1e-8.
         decay (float | None, optional):
             Decays covariance matrix accumulators, this may be useful if `shampoo_beta` is None. Defaults to None.
-        unprojected_exp_avg (bool, optional):
-            whether to update first momentum in unprojected space. Both true and false work and lead to different
-            results but True usually works better. Defaults to True.
+        alpha (float, optional):
+            learning rate. Defaults to 1.
         bias_correction (bool, optional):
             enables adam bias correction. Defaults to True.
     """
@@ -170,7 +169,6 @@ class SOAP(Transform):
         eps: float = 1e-8,
         decay: float | None = None,
         alpha: float = 1,
-        unprojected_exp_avg: bool = True,
         bias_correction: bool = True,
     ):
         defaults = dict(
@@ -183,21 +181,18 @@ class SOAP(Transform):
             precondition_1d=precondition_1d,
             eps=eps,
             decay=decay,
-            unprojected_exp_avg=unprojected_exp_avg,
             bias_correction=bias_correction,
             alpha=alpha,
         )
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         updates = []
         # update preconditioners
-        for i,(p,t) in enumerate(zip(params, tensors)):
-            state = self.state[p]
-            settings = self.settings[p]
-            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps, unprojected_exp_avg,alpha = itemgetter(
-                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps', 'unprojected_exp_avg','alpha')(settings)
+        for i,(p,t, state, setting) in enumerate(zip(params, tensors, states, settings)):
+            beta1, beta2, shampoo_beta, merge_small, max_dim, precondition_1d, eps,alpha = itemgetter(
+                'beta1', 'beta2', 'shampoo_beta', 'merge_small', 'max_dim', 'precondition_1d', 'eps','alpha')(setting)
 
             if merge_small:
                 t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -237,10 +232,7 @@ class SOAP(Transform):
             exp_avg: torch.Tensor = state["exp_avg"]
             exp_avg_sq: torch.Tensor = state["exp_avg_sq"]
 
-            if unprojected_exp_avg or t_projected is None:
-                exp_avg.lerp_(t, 1-beta1)
-            else:
-                exp_avg.lerp_(t_projected, 1-beta1)
+            exp_avg.lerp_(t, 1-beta1)
 
             if t_projected is None:
                 exp_avg_sq.mul_(beta2).addcmul_(t, t, value=1-beta2)
@@ -249,7 +241,7 @@ class SOAP(Transform):
 
             # project exponential moving averages if they are accumulated unprojected
             exp_avg_projected = exp_avg
-            if unprojected_exp_avg and t_projected is not None:
+            if t_projected is not None:
                 exp_avg_projected = project(exp_avg, state['Q'])
 
             exp_avg_sq_projected = exp_avg_sq
@@ -260,10 +252,11 @@ class SOAP(Transform):
             # Projecting back the preconditioned (by Adam) exponential moving average of gradients
             # to the original space
             update = exp_avg_projected / denom
+
             if t_projected is not None:
                 update = project_back(update, state["Q"])
 
-            if settings['bias_correction']:
+            if setting['bias_correction']:
                 bias_correction1 = 1.0 - beta1 ** (state["step"]+1)
                 bias_correction2 = 1.0 - beta2 ** (state["step"]+1)
                 update *= ((bias_correction2 ** .5) / bias_correction1) * alpha
@@ -279,7 +272,7 @@ class SOAP(Transform):
             # Update is done after the gradient step to avoid using current gradients in the projection.
             if state['GG'] is not None:
                 update_soap_covariances_(t, state['GG'], shampoo_beta)
-                if state['step'] % settings['precond_freq'] == 0:
+                if state['step'] % setting['precond_freq'] == 0:
                     state['Q'], state['exp_avg_sq'] = get_orthogonal_matrix_QR(exp_avg_sq, state['GG'], state['Q'])
 
         return updates
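With `unprojected_exp_avg` removed, SOAP now always keeps the first moment in the original parameter space and projects it into the Shampoo eigenbasis only when the Adam-style step is computed, while the second moment stays in the projected space. A single-matrix sketch of that flow, under my own simplifying assumptions (the real code handles arbitrary ranks via per-dimension `Q` matrices, merged dims, bias correction, and periodic QR refreshes of the eigenbasis):

import torch

def soap_step_2d_reference(g: torch.Tensor, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor,
                           QL: torch.Tensor, QR: torch.Tensor,
                           beta1: float = 0.95, beta2: float = 0.95, eps: float = 1e-8) -> torch.Tensor:
    """SOAP = Adam run inside Shampoo's eigenbasis, with the first moment kept unprojected."""
    g_proj = QL.T @ g @ QR                                            # rotate gradient into the eigenbasis
    exp_avg.lerp_(g, 1 - beta1)                                       # first moment, original space
    exp_avg_sq.mul_(beta2).addcmul_(g_proj, g_proj, value=1 - beta2)  # second moment, projected space
    update = (QL.T @ exp_avg @ QR) / exp_avg_sq.sqrt().add_(eps)      # Adam-style step in the eigenbasis
    return QL @ update @ QR.T                                         # rotate back to parameter space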