torchzero 0.3.8 (py3-none-any.whl) → 0.3.10 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. tests/test_opts.py +55 -22
  2. tests/test_tensorlist.py +3 -3
  3. tests/test_vars.py +61 -61
  4. torchzero/core/__init__.py +2 -3
  5. torchzero/core/module.py +49 -49
  6. torchzero/core/transform.py +219 -158
  7. torchzero/modules/__init__.py +1 -0
  8. torchzero/modules/clipping/clipping.py +10 -10
  9. torchzero/modules/clipping/ema_clipping.py +14 -13
  10. torchzero/modules/clipping/growth_clipping.py +16 -18
  11. torchzero/modules/experimental/__init__.py +12 -3
  12. torchzero/modules/experimental/absoap.py +50 -156
  13. torchzero/modules/experimental/adadam.py +15 -14
  14. torchzero/modules/experimental/adamY.py +17 -27
  15. torchzero/modules/experimental/adasoap.py +20 -130
  16. torchzero/modules/experimental/curveball.py +12 -12
  17. torchzero/modules/experimental/diagonal_higher_order_newton.py +225 -0
  18. torchzero/modules/experimental/eigendescent.py +117 -0
  19. torchzero/modules/experimental/etf.py +172 -0
  20. torchzero/modules/experimental/gradmin.py +2 -2
  21. torchzero/modules/experimental/newton_solver.py +11 -11
  22. torchzero/modules/experimental/newtonnewton.py +88 -0
  23. torchzero/modules/experimental/reduce_outward_lr.py +8 -5
  24. torchzero/modules/experimental/soapy.py +19 -146
  25. torchzero/modules/experimental/spectral.py +79 -204
  26. torchzero/modules/experimental/structured_newton.py +111 -0
  27. torchzero/modules/experimental/subspace_preconditioners.py +13 -10
  28. torchzero/modules/experimental/tada.py +38 -0
  29. torchzero/modules/grad_approximation/fdm.py +2 -2
  30. torchzero/modules/grad_approximation/forward_gradient.py +5 -5
  31. torchzero/modules/grad_approximation/grad_approximator.py +21 -21
  32. torchzero/modules/grad_approximation/rfdm.py +28 -15
  33. torchzero/modules/higher_order/__init__.py +1 -0
  34. torchzero/modules/higher_order/higher_order_newton.py +256 -0
  35. torchzero/modules/line_search/backtracking.py +42 -23
  36. torchzero/modules/line_search/line_search.py +40 -40
  37. torchzero/modules/line_search/scipy.py +18 -3
  38. torchzero/modules/line_search/strong_wolfe.py +21 -32
  39. torchzero/modules/line_search/trust_region.py +18 -6
  40. torchzero/modules/lr/__init__.py +1 -1
  41. torchzero/modules/lr/{step_size.py → adaptive.py} +22 -26
  42. torchzero/modules/lr/lr.py +20 -16
  43. torchzero/modules/momentum/averaging.py +25 -10
  44. torchzero/modules/momentum/cautious.py +73 -35
  45. torchzero/modules/momentum/ema.py +92 -41
  46. torchzero/modules/momentum/experimental.py +21 -13
  47. torchzero/modules/momentum/matrix_momentum.py +96 -54
  48. torchzero/modules/momentum/momentum.py +24 -4
  49. torchzero/modules/ops/accumulate.py +51 -21
  50. torchzero/modules/ops/binary.py +36 -36
  51. torchzero/modules/ops/debug.py +7 -7
  52. torchzero/modules/ops/misc.py +128 -129
  53. torchzero/modules/ops/multi.py +19 -19
  54. torchzero/modules/ops/reduce.py +16 -16
  55. torchzero/modules/ops/split.py +26 -26
  56. torchzero/modules/ops/switch.py +4 -4
  57. torchzero/modules/ops/unary.py +20 -20
  58. torchzero/modules/ops/utility.py +37 -37
  59. torchzero/modules/optimizers/adagrad.py +33 -24
  60. torchzero/modules/optimizers/adam.py +31 -34
  61. torchzero/modules/optimizers/lion.py +4 -4
  62. torchzero/modules/optimizers/muon.py +6 -6
  63. torchzero/modules/optimizers/orthograd.py +4 -5
  64. torchzero/modules/optimizers/rmsprop.py +13 -16
  65. torchzero/modules/optimizers/rprop.py +52 -49
  66. torchzero/modules/optimizers/shampoo.py +17 -23
  67. torchzero/modules/optimizers/soap.py +12 -19
  68. torchzero/modules/optimizers/sophia_h.py +13 -13
  69. torchzero/modules/projections/dct.py +4 -4
  70. torchzero/modules/projections/fft.py +6 -6
  71. torchzero/modules/projections/galore.py +1 -1
  72. torchzero/modules/projections/projection.py +57 -57
  73. torchzero/modules/projections/structural.py +17 -17
  74. torchzero/modules/quasi_newton/__init__.py +33 -4
  75. torchzero/modules/quasi_newton/cg.py +76 -26
  76. torchzero/modules/quasi_newton/experimental/modular_lbfgs.py +24 -24
  77. torchzero/modules/quasi_newton/lbfgs.py +15 -15
  78. torchzero/modules/quasi_newton/lsr1.py +18 -17
  79. torchzero/modules/quasi_newton/olbfgs.py +19 -19
  80. torchzero/modules/quasi_newton/quasi_newton.py +257 -48
  81. torchzero/modules/second_order/newton.py +38 -21
  82. torchzero/modules/second_order/newton_cg.py +13 -12
  83. torchzero/modules/second_order/nystrom.py +19 -19
  84. torchzero/modules/smoothing/gaussian.py +21 -21
  85. torchzero/modules/smoothing/laplacian.py +7 -9
  86. torchzero/modules/weight_decay/__init__.py +1 -1
  87. torchzero/modules/weight_decay/weight_decay.py +43 -9
  88. torchzero/modules/wrappers/optim_wrapper.py +11 -11
  89. torchzero/optim/wrappers/directsearch.py +244 -0
  90. torchzero/optim/wrappers/fcmaes.py +97 -0
  91. torchzero/optim/wrappers/mads.py +90 -0
  92. torchzero/optim/wrappers/nevergrad.py +4 -4
  93. torchzero/optim/wrappers/nlopt.py +28 -14
  94. torchzero/optim/wrappers/optuna.py +70 -0
  95. torchzero/optim/wrappers/scipy.py +162 -13
  96. torchzero/utils/__init__.py +2 -6
  97. torchzero/utils/derivatives.py +2 -1
  98. torchzero/utils/optimizer.py +55 -74
  99. torchzero/utils/python_tools.py +17 -4
  100. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/METADATA +14 -14
  101. torchzero-0.3.10.dist-info/RECORD +139 -0
  102. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/WHEEL +1 -1
  103. torchzero/core/preconditioner.py +0 -138
  104. torchzero/modules/experimental/algebraic_newton.py +0 -145
  105. torchzero/modules/experimental/tropical_newton.py +0 -136
  106. torchzero-0.3.8.dist-info/RECORD +0 -130
  107. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/licenses/LICENSE +0 -0
  108. {torchzero-0.3.8.dist-info → torchzero-0.3.10.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/cautious.py

@@ -1,3 +1,4 @@
+"""Cautioning related modules"""
 from collections import deque
 from operator import itemgetter
 from typing import Literal
@@ -5,7 +6,7 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform, Module, Chainable
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_dicts
 
 
 def cautious_(
@@ -64,27 +65,33 @@ class Cautious(Transform):
         normalize=False,
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
-        target: Target = "update",
     ):
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
-        super().__init__(defaults, uses_grad=True, target=target)
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[params[0]])
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
 
 class UpdateGradientSignConsistency(Transform):
-    """1 where signs match 0 otherwise"""
-    def __init__(self, normalize = False, eps=1e-6, target: Target = 'update'):
+    """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.
+
+    Args:
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+    """
+    def __init__(self, normalize = False, eps=1e-6):
+
         defaults = dict(normalize=normalize, eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        normalize, eps = itemgetter('normalize', 'eps')(self.settings[params[0]])
+        normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
         mask = (TensorList(tensors).mul_(grads)).gt_(0)
         if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
@@ -92,6 +99,23 @@ class UpdateGradientSignConsistency(Transform):
         return mask
 
 class IntermoduleCautious(Module):
+    """Negaties update on :code:`main` module where it's sign doesn't match with output of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be cautioned.
+        compare (Chainable): modules or sequence of modules to compare the sign to.
+        normalize (bool, optional):
+            renormalize update after masking. Defaults to False.
+        eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+        mode (str, optional):
+            what to do with updates with inconsistent signs.
+
+            "zero" - set them to zero (as in paper)
+
+            "grad" - set them to the gradient
+
+            "backtrack" - negate them (same as using update magnitude and gradient sign)
+    """
     def __init__(
         self,
         main: Chainable,
@@ -100,6 +124,7 @@ class IntermoduleCautious(Module):
         eps=1e-6,
         mode: Literal["zero", "grad", "backtrack"] = "zero",
     ):
+
         defaults = dict(normalize=normalize, eps=eps, mode=mode)
         super().__init__(defaults)
 
@@ -107,40 +132,45 @@ class IntermoduleCautious(Module):
         self.set_child('compare', compare)
 
     @torch.no_grad
-    def step(self, vars):
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']
 
-        main_vars = main.step(vars.clone(clone_update=True))
-        vars.update_attrs_from_clone_(main_vars)
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)
 
-        compare_vars = compare.step(vars.clone(clone_update=True))
-        vars.update_attrs_from_clone_(compare_vars)
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)
 
-        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[vars.params[0]])
-        vars.update = cautious_(
-            TensorList(main_vars.get_update()),
-            TensorList(compare_vars.get_update()),
+        mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+        var.update = cautious_(
+            TensorList(main_var.get_update()),
+            TensorList(compare_var.get_update()),
             normalize=normalize,
             mode=mode,
            eps=eps,
         )
 
-        return vars
+        return var
 
 class ScaleByGradCosineSimilarity(Transform):
+    """Multiplies the update by cosine similarity with gradient.
+    If cosine similarity is negative, naturally the update will be negated as well.
+
+    Args:
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+    """
     def __init__(
         self,
-        eps=1e-6,
-        target: Target = "update",
+        eps: float = 1e-6,
     ):
         defaults = dict(eps=eps)
-        super().__init__(defaults, uses_grad=True, target=target)
+        super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
-        eps = self.settings[params[0]]['eps']
+        eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
         cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
@@ -148,6 +178,14 @@ class ScaleByGradCosineSimilarity(Transform):
         return tensors.mul_(cos_sim)
 
 class ScaleModulesByCosineSimilarity(Module):
+    """Scales the output of :code:`main` module by it's cosine similarity to the output
+    of :code:`compare` module.
+
+    Args:
+        main (Chainable): main module or sequence of modules whose update will be scaled.
+        compare (Chainable): module or sequence of modules to compare to
+        eps (float, optional): epsilon for division. Defaults to 1e-6.
+    """
     def __init__(
         self,
         main: Chainable,
@@ -161,21 +199,21 @@ class ScaleModulesByCosineSimilarity(Module):
         self.set_child('compare', compare)
 
     @torch.no_grad
-    def step(self, vars):
+    def step(self, var):
         main = self.children['main']
         compare = self.children['compare']
 
-        main_vars = main.step(vars.clone(clone_update=True))
-        vars.update_attrs_from_clone_(main_vars)
+        main_var = main.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(main_var)
 
-        compare_vars = compare.step(vars.clone(clone_update=True))
-        vars.update_attrs_from_clone_(compare_vars)
+        compare_var = compare.step(var.clone(clone_update=True))
+        var.update_attrs_from_clone_(compare_var)
 
-        m = TensorList(main_vars.get_update())
-        c = TensorList(compare_vars.get_update())
-        eps = self.settings[vars.params[0]]['eps']
+        m = TensorList(main_var.get_update())
+        c = TensorList(compare_var.get_update())
+        eps = self.settings[var.params[0]]['eps']
 
         cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
-        vars.update = m.mul_(cos_sim)
-        return vars
+        var.update = m.mul_(cos_sim)
+        return var
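
The cautious.py hunks above capture the transform API change that runs through the rest of this release: the per-parameter hook is renamed from transform(self, tensors, params, grads, vars) to apply(self, tensors, params, grads, loss, states, settings), per-parameter settings arrive as a list of dicts (settings[0] instead of self.settings[params[0]]), and the target constructor argument disappears from several transforms. The sketch below shows what a minimal custom transform could look like against the new signature; it is illustrative only and assumes the Transform base class builds states and settings per parameter, as the diff suggests.

# Illustrative sketch only: a minimal transform written against the 0.3.10-style
# hook visible in the hunks above. It assumes the Transform base class passes one
# state dict and one settings dict per parameter; base-class internals are not
# part of this diff.
import torch
from torchzero.core import Transform
from torchzero.utils import TensorList

class GradSignMask(Transform):
    """Zeroes update entries whose sign disagrees with the gradient (a Cautious-style rule)."""
    def __init__(self):
        defaults = {}
        super().__init__(defaults, uses_grad=True)  # no `target=` argument here in 0.3.10

    @torch.no_grad
    def apply(self, tensors, params, grads, loss, states, settings):
        assert grads is not None
        update = TensorList(tensors)
        # 1 where update and gradient agree in sign, 0 otherwise
        mask = TensorList([(t * g) > 0 for t, g in zip(tensors, grads)])
        return update.mul_(mask)  # the returned tensors become the new update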
torchzero/modules/momentum/ema.py

@@ -5,18 +5,19 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList
+from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
 from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
 
 
 class EMA(Transform):
-    """Maintains EMA of update.
+    """Maintains an exponential moving average of update.
 
     Args:
         momentum (float, optional): momentum (beta). Defaults to 0.9.
         dampening (float, optional): momentum dampening. Defaults to 0.
         debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
         lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+        ema_init (str, optional): initial values for the EMA, "zeros" or "update".
         target (Target, optional): target to apply EMA to. Defaults to 'update'.
     """
     def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
@@ -24,13 +25,14 @@ class EMA(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(self.settings[params[0]])
+        debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
 
-        exp_avg = self.get_state('exp_avg', params=params, init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
-        momentum, dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+        exp_avg = unpack_states(states, tensors, 'exp_avg',
+                                init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+        momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
 
         exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
 
@@ -39,44 +41,58 @@
 
 
 class EMASquared(Transform):
+    """Maintains an exponential moving average of squared updates.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
     EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)
 
-    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2, target: Target = 'update'):
+    def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
 
 class SqrtEMASquared(Transform):
-    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+    """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.
 
-    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update',):
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+    def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
         defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, uses_grad=False)
 
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+            exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+            exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return self.SQRT_EMA_SQ_FN(
@@ -91,47 +107,73 @@ class SqrtEMASquared(Transform):
 
 
 class Debias(Transform):
+    """Multiplies the update by an Adam debiasing term based first and/or second momentum.
+
+    Args:
+        beta1 (float | None, optional):
+            first momentum, should be the same as first momentum used in modules before. Defaults to None.
+        beta2 (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        alpha (float, optional): learning rate. Defaults to 1.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
         defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        settings = self.settings[params[0]]
-        pow = settings['pow']
-        alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+        pow = settings[0]['pow']
+        alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)
 
         return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
 
 class Debias2(Transform):
+    """Multiplies the update by an Adam debiasing term based on the second momentum.
+
+    Args:
+        beta (float | None, optional):
+            second (squared) momentum, should be the same as second momentum used in modules before. Defaults to None.
+        pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+        target (Target, optional): target. Defaults to 'update'.
+    """
     def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
         defaults = dict(beta=beta, pow=pow)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        pow = self.settings[params[0]]['pow']
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        pow = settings[0]['pow']
+        beta = NumberList(s['beta'] for s in settings)
         return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
 
 class CenteredEMASquared(Transform):
-    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2, target: Target = 'update'):
+    """Maintains a centered exponential moving average of squared updates. This also maintains an additional
+    exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return centered_ema_sq_(
@@ -144,21 +186,30 @@ class CenteredEMASquared(Transform):
         ).clone()
 
 class CenteredSqrtEMASquared(Transform):
-    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update'):
+    """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
+    This also maintains an additional exponential moving average of un-squared updates, square of which is subtracted from the EMA.
+
+    Args:
+        beta (float, optional): momentum value. Defaults to 0.999.
+        amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+        debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+        pow (float, optional): power, absolute value is always used. Defaults to 2.
+    """
+    def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
         defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
-        super().__init__(defaults, uses_grad=False, target=target)
+        super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
+    def apply(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
-        beta = self.get_settings('beta', params=params, cls=NumberList)
+        amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+        beta = NumberList(s['beta'] for s in settings)
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return sqrt_centered_ema_sq_(
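
Several of the classes above (SqrtEMASquared, Debias, Debias2, CenteredSqrtEMASquared) refer to the "Adam debiasing term". For reference, under the usual Adam convention with step count $t$ and the default pow=2, the bias-corrected moments are

$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t},$$

so an update assembled from the raw moments as $m_t/\sqrt{v_t}$ gets multiplied by $\alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}$ when both betas are corrected (as in Debias) or by $\sqrt{1-\beta_2^t}$ alone when only the second momentum is (as in Debias2). How alpha and a non-default pow enter the library's debias and debias_second_momentum helpers is inferred from the argument names here, not shown in this diff.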
torchzero/modules/momentum/experimental.py

@@ -6,7 +6,7 @@ from typing import Literal
 import torch
 
 from ...core import Target, Transform
-from ...utils import NumberList, TensorList
+from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
 from ..functional import ema_, ema_sq_, sqrt_ema_sq_
 from .ema import EMASquared, SqrtEMASquared
 from .momentum import nag_
@@ -43,22 +43,22 @@ def precentered_ema_sq_(
     return exp_avg_sq_
 
 class PrecenteredEMASquared(Transform):
+    """Maintains un-squared EMA, the updates are centered by it before being fed into squared EMA."""
     def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
         defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
         super().__init__(defaults, uses_grad=False, target=target)
-        self.current_step = 0
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        self.current_step += 1
+    def apply(self, tensors, params, grads, loss, states, settings):
+        step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-        beta1, beta2 = self.get_settings('beta1','beta2', params=params, cls=NumberList)
-        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(self.settings[params[0]])
+        beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+        amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])
 
         if amsgrad:
-            exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
         else:
-            exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+            exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
             max_exp_avg_sq = None
 
         return precentered_ema_sq_(
@@ -67,7 +67,7 @@ class PrecenteredEMASquared(Transform):
             exp_avg_sq_=exp_avg_sq,
             beta1=beta1,
             beta2=beta2,
-            step = self.current_step,
+            step = step,
             min_step=min_step,
             pow=pow,
             max_exp_avg_sq_=max_exp_avg_sq,
@@ -119,9 +119,11 @@ def sqrt_nag_ema_sq_(
         pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
 
 class NesterovEMASquared(EMASquared):
+    """squared momentum with nesterov momentum rule"""
     EMA_SQ_FN = staticmethod(nag_ema_sq_)
 
 class SqrtNesterovEMASquared(SqrtEMASquared):
+    """square root of squared momentum with nesterov momentum rule"""
     SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
 
 
@@ -141,14 +143,20 @@ def coordinate_momentum_(
 
 
 class CoordinateMomentum(Transform):
+    """Maintains a momentum buffer, on each step each value in the buffer has :code:`p` chance to be updated with the new value.
+
+    Args:
+        p (float, optional): _description_. Defaults to 0.1.
+        target (Target, optional): _description_. Defaults to 'update'.
+    """
     def __init__(self, p: float = 0.1, target: Target = 'update'):
         defaults = dict(p=p)
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def transform(self, tensors, params, grads, vars):
-        p = self.get_settings('p', params=params, cls=NumberList)
-        velocity = self.get_state('velocity', params=params, cls=TensorList)
+    def apply(self, tensors, params, grads, loss, states, settings):
+        p = NumberList(s['p'] for s in settings)
+        velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
 
 
@@ -180,7 +188,7 @@ class CoordinateMomentum(Transform):
     # super().__init__(defaults, uses_grad=False)
 
     # @torch.no_grad
-    # def transform(self, tensors, params, grads, vars):
+    # def apply(self, tensors, params, grads, loss, states, settings):
     # momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
     # abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
     # velocity = self.get_state('velocity', params=params, cls=TensorList)
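
The CoordinateMomentum docstring added above describes the rule informally: on every step, each entry of the momentum buffer has probability p of being overwritten by the incoming update value, and is kept otherwise. Below is a standalone sketch of that rule on a single tensor, for illustration only; torchzero's coordinate_momentum_ operates on TensorLists and may differ in details.

# Standalone sketch of the documented coordinate-momentum rule (not the library's
# coordinate_momentum_ implementation): replace each buffer entry with probability p.
import torch

def coordinate_momentum_step(update: torch.Tensor, velocity: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    replace = torch.rand_like(velocity) < p                 # Bernoulli(p) mask per coordinate
    velocity.copy_(torch.where(replace, update, velocity))  # update the buffer in place
    return velocity.clone()                                 # the buffer becomes the new update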