torchzero 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +115 -68
  5. tests/test_tensorlist.py +2 -2
  6. tests/test_vars.py +62 -61
  7. torchzero/core/__init__.py +2 -3
  8. torchzero/core/module.py +185 -53
  9. torchzero/core/transform.py +327 -159
  10. torchzero/modules/__init__.py +3 -1
  11. torchzero/modules/clipping/clipping.py +120 -23
  12. torchzero/modules/clipping/ema_clipping.py +37 -22
  13. torchzero/modules/clipping/growth_clipping.py +20 -21
  14. torchzero/modules/experimental/__init__.py +30 -4
  15. torchzero/modules/experimental/absoap.py +53 -156
  16. torchzero/modules/experimental/adadam.py +22 -15
  17. torchzero/modules/experimental/adamY.py +21 -25
  18. torchzero/modules/experimental/adam_lambertw.py +149 -0
  19. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +37 -8
  20. torchzero/modules/experimental/adasoap.py +24 -129
  21. torchzero/modules/experimental/cosine.py +214 -0
  22. torchzero/modules/experimental/cubic_adam.py +97 -0
  23. torchzero/modules/experimental/curveball.py +12 -12
  24. torchzero/modules/{projections → experimental}/dct.py +11 -11
  25. torchzero/modules/experimental/eigendescent.py +120 -0
  26. torchzero/modules/experimental/etf.py +195 -0
  27. torchzero/modules/experimental/exp_adam.py +113 -0
  28. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  29. torchzero/modules/{projections → experimental}/fft.py +10 -10
  30. torchzero/modules/experimental/gradmin.py +2 -2
  31. torchzero/modules/experimental/hnewton.py +85 -0
  32. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +49 -50
  33. torchzero/modules/experimental/newton_solver.py +11 -11
  34. torchzero/modules/experimental/newtonnewton.py +92 -0
  35. torchzero/modules/experimental/parabolic_search.py +220 -0
  36. torchzero/modules/experimental/reduce_outward_lr.py +10 -7
  37. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  38. torchzero/modules/experimental/subspace_preconditioners.py +20 -10
  39. torchzero/modules/experimental/tensor_adagrad.py +42 -0
  40. torchzero/modules/functional.py +12 -2
  41. torchzero/modules/grad_approximation/fdm.py +31 -4
  42. torchzero/modules/grad_approximation/forward_gradient.py +17 -7
  43. torchzero/modules/grad_approximation/grad_approximator.py +69 -24
  44. torchzero/modules/grad_approximation/rfdm.py +310 -50
  45. torchzero/modules/higher_order/__init__.py +1 -0
  46. torchzero/modules/higher_order/higher_order_newton.py +319 -0
  47. torchzero/modules/line_search/__init__.py +4 -4
  48. torchzero/modules/line_search/adaptive.py +99 -0
  49. torchzero/modules/line_search/backtracking.py +75 -31
  50. torchzero/modules/line_search/line_search.py +107 -49
  51. torchzero/modules/line_search/polynomial.py +233 -0
  52. torchzero/modules/line_search/scipy.py +20 -5
  53. torchzero/modules/line_search/strong_wolfe.py +52 -36
  54. torchzero/modules/misc/__init__.py +27 -0
  55. torchzero/modules/misc/debug.py +48 -0
  56. torchzero/modules/misc/escape.py +60 -0
  57. torchzero/modules/misc/gradient_accumulation.py +70 -0
  58. torchzero/modules/misc/misc.py +316 -0
  59. torchzero/modules/misc/multistep.py +158 -0
  60. torchzero/modules/misc/regularization.py +171 -0
  61. torchzero/modules/misc/split.py +103 -0
  62. torchzero/modules/{ops → misc}/switch.py +48 -7
  63. torchzero/modules/momentum/__init__.py +1 -1
  64. torchzero/modules/momentum/averaging.py +25 -10
  65. torchzero/modules/momentum/cautious.py +115 -40
  66. torchzero/modules/momentum/ema.py +92 -41
  67. torchzero/modules/momentum/experimental.py +21 -13
  68. torchzero/modules/momentum/matrix_momentum.py +145 -76
  69. torchzero/modules/momentum/momentum.py +25 -4
  70. torchzero/modules/ops/__init__.py +3 -31
  71. torchzero/modules/ops/accumulate.py +51 -25
  72. torchzero/modules/ops/binary.py +108 -62
  73. torchzero/modules/ops/multi.py +95 -34
  74. torchzero/modules/ops/reduce.py +31 -23
  75. torchzero/modules/ops/unary.py +37 -21
  76. torchzero/modules/ops/utility.py +53 -45
  77. torchzero/modules/optimizers/__init__.py +12 -3
  78. torchzero/modules/optimizers/adagrad.py +48 -29
  79. torchzero/modules/optimizers/adahessian.py +223 -0
  80. torchzero/modules/optimizers/adam.py +35 -37
  81. torchzero/modules/optimizers/adan.py +110 -0
  82. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  83. torchzero/modules/optimizers/esgd.py +171 -0
  84. torchzero/modules/optimizers/ladagrad.py +183 -0
  85. torchzero/modules/optimizers/lion.py +4 -4
  86. torchzero/modules/optimizers/mars.py +91 -0
  87. torchzero/modules/optimizers/msam.py +186 -0
  88. torchzero/modules/optimizers/muon.py +32 -7
  89. torchzero/modules/optimizers/orthograd.py +4 -5
  90. torchzero/modules/optimizers/rmsprop.py +19 -19
  91. torchzero/modules/optimizers/rprop.py +89 -52
  92. torchzero/modules/optimizers/sam.py +163 -0
  93. torchzero/modules/optimizers/shampoo.py +55 -27
  94. torchzero/modules/optimizers/soap.py +40 -37
  95. torchzero/modules/optimizers/sophia_h.py +82 -25
  96. torchzero/modules/projections/__init__.py +2 -4
  97. torchzero/modules/projections/cast.py +51 -0
  98. torchzero/modules/projections/galore.py +4 -2
  99. torchzero/modules/projections/projection.py +212 -118
  100. torchzero/modules/quasi_newton/__init__.py +44 -5
  101. torchzero/modules/quasi_newton/cg.py +190 -39
  102. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  103. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  104. torchzero/modules/quasi_newton/lsr1.py +102 -58
  105. torchzero/modules/quasi_newton/quasi_newton.py +1032 -177
  106. torchzero/modules/quasi_newton/trust_region.py +397 -0
  107. torchzero/modules/second_order/__init__.py +2 -2
  108. torchzero/modules/second_order/newton.py +245 -54
  109. torchzero/modules/second_order/newton_cg.py +311 -21
  110. torchzero/modules/second_order/nystrom.py +124 -21
  111. torchzero/modules/smoothing/gaussian.py +55 -21
  112. torchzero/modules/smoothing/laplacian.py +20 -12
  113. torchzero/modules/step_size/__init__.py +2 -0
  114. torchzero/modules/step_size/adaptive.py +122 -0
  115. torchzero/modules/step_size/lr.py +154 -0
  116. torchzero/modules/weight_decay/__init__.py +1 -1
  117. torchzero/modules/weight_decay/weight_decay.py +126 -10
  118. torchzero/modules/wrappers/optim_wrapper.py +40 -12
  119. torchzero/optim/wrappers/directsearch.py +281 -0
  120. torchzero/optim/wrappers/fcmaes.py +105 -0
  121. torchzero/optim/wrappers/mads.py +89 -0
  122. torchzero/optim/wrappers/nevergrad.py +20 -5
  123. torchzero/optim/wrappers/nlopt.py +28 -14
  124. torchzero/optim/wrappers/optuna.py +70 -0
  125. torchzero/optim/wrappers/scipy.py +167 -16
  126. torchzero/utils/__init__.py +3 -7
  127. torchzero/utils/derivatives.py +5 -4
  128. torchzero/utils/linalg/__init__.py +1 -1
  129. torchzero/utils/linalg/solve.py +251 -12
  130. torchzero/utils/numberlist.py +2 -0
  131. torchzero/utils/optimizer.py +55 -74
  132. torchzero/utils/python_tools.py +27 -4
  133. torchzero/utils/tensorlist.py +40 -28
  134. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/METADATA +76 -51
  135. torchzero-0.3.11.dist-info/RECORD +159 -0
  136. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/WHEEL +1 -1
  137. torchzero/core/preconditioner.py +0 -138
  138. torchzero/modules/experimental/algebraic_newton.py +0 -145
  139. torchzero/modules/experimental/soapy.py +0 -290
  140. torchzero/modules/experimental/spectral.py +0 -288
  141. torchzero/modules/experimental/structured_newton.py +0 -111
  142. torchzero/modules/experimental/tropical_newton.py +0 -136
  143. torchzero/modules/lr/__init__.py +0 -2
  144. torchzero/modules/lr/lr.py +0 -59
  145. torchzero/modules/lr/step_size.py +0 -97
  146. torchzero/modules/ops/debug.py +0 -25
  147. torchzero/modules/ops/misc.py +0 -419
  148. torchzero/modules/ops/split.py +0 -75
  149. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  150. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  151. torchzero-0.3.9.dist-info/RECORD +0 -131
  152. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  153. {torchzero-0.3.9.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
torchzero/modules/momentum/cautious.py

@@ -1,3 +1,4 @@
+ """Cautioning related modules"""
  from collections import deque
  from operator import itemgetter
  from typing import Literal
@@ -5,7 +6,7 @@ from typing import Literal
  import torch
 
  from ...core import Target, Transform, Module, Chainable
- from ...utils import NumberList, TensorList
+ from ...utils import NumberList, TensorList, unpack_dicts
 
 
  def cautious_(
@@ -54,9 +55,20 @@ class Cautious(Transform):
 
          "backtrack" - negate them (same as using update magnitude and gradient sign)
 
-     reference
-         *Cautious Optimizers: Improving Training with One Line of Code.
-         Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+     Examples:
+         Cautious Adam
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 bench.parameters(),
+                 tz.m.Adam(),
+                 tz.m.Cautious(),
+                 tz.m.LR(1e-2)
+             )
+
+     References:
+         Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
      """
 
      def __init__(
@@ -64,27 +76,33 @@ class Cautious(Transform):
          normalize=False,
          eps=1e-6,
          mode: Literal["zero", "grad", "backtrack"] = "zero",
-         target: Target = "update",
      ):
          defaults = dict(normalize=normalize, eps=eps, mode=mode)
-         super().__init__(defaults, uses_grad=True, target=target)
+         super().__init__(defaults, uses_grad=True)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          assert grads is not None
-         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[params[0]])
+         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
          return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
 
  class UpdateGradientSignConsistency(Transform):
-     """1 where signs match 0 otherwise"""
-     def __init__(self, normalize = False, eps=1e-6, target: Target = 'update'):
+     """Compares update and gradient signs. Output will have 1s where signs match, and 0s where they don't.
+
+     Args:
+         normalize (bool, optional):
+             renormalize update after masking. Defaults to False.
+         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+     """
+     def __init__(self, normalize = False, eps=1e-6):
+
          defaults = dict(normalize=normalize, eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
+         super().__init__(defaults, uses_grad=True)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          assert grads is not None
-         normalize, eps = itemgetter('normalize', 'eps')(self.settings[params[0]])
+         normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
          mask = (TensorList(tensors).mul_(grads)).gt_(0)
          if normalize: mask = mask / mask.global_mean().clip(min = eps) # pyright: ignore[reportOperatorIssue]
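
Note: the recurring change across these hunks is that the per-tensor hook is renamed from transform(self, tensors, params, grads, vars) to apply_tensors(self, tensors, params, grads, loss, states, settings), with per-parameter settings now passed in as a list of dicts instead of being read from self.settings. A minimal before/after sketch of a custom transform, going only by the signatures visible in this diff (the MyScale class and its behaviour are hypothetical):

    import torch
    from torchzero.core import Transform

    class MyScale(Transform):
        def __init__(self, scale: float = 1.0):
            super().__init__(dict(scale=scale), uses_grad=False)

        # torchzero 0.3.9 style (old hook, settings read off self.settings):
        # @torch.no_grad
        # def transform(self, tensors, params, grads, vars):
        #     scale = self.settings[params[0]]['scale']
        #     return [t.mul_(scale) for t in tensors]

        # torchzero 0.3.11 style (new hook, per-parameter settings passed in):
        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            scale = settings[0]['scale']
            return [t.mul_(scale) for t in tensors]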
@@ -92,6 +110,23 @@ class UpdateGradientSignConsistency(Transform):
          return mask
 
  class IntermoduleCautious(Module):
+     """Negates the update of the :code:`main` module where its sign doesn't match the output of the :code:`compare` module.
+
+     Args:
+         main (Chainable): main module or sequence of modules whose update will be cautioned.
+         compare (Chainable): module or sequence of modules to compare the sign to.
+         normalize (bool, optional):
+             renormalize update after masking. Defaults to False.
+         eps (float, optional): epsilon for normalization. Defaults to 1e-6.
+         mode (str, optional):
+             what to do with updates with inconsistent signs.
+
+             "zero" - set them to zero (as in the paper)
+
+             "grad" - set them to the gradient
+
+             "backtrack" - negate them (same as using update magnitude and gradient sign)
+     """
      def __init__(
          self,
          main: Chainable,
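
For orientation, the cautioning rule described above reduces to a per-element sign check between the update and a reference direction. A rough standalone sketch in plain PyTorch (single tensors rather than the library's TensorList; the in-package implementation is the cautious_ function in this file):

    import torch

    def cautious_sketch(update: torch.Tensor, grad: torch.Tensor, mode: str = "zero") -> torch.Tensor:
        # entries where the update and the reference direction point the same way
        agree = (update * grad) > 0
        if mode == "zero":        # zero out disagreeing entries (as in the paper)
            return update * agree
        if mode == "grad":        # replace disagreeing entries with the gradient
            return torch.where(agree, update, grad)
        if mode == "backtrack":   # negate disagreeing entries
            return torch.where(agree, update, -update)
        raise ValueError(mode)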
@@ -100,6 +135,7 @@ class IntermoduleCautious(Module):
          eps=1e-6,
          mode: Literal["zero", "grad", "backtrack"] = "zero",
      ):
+
          defaults = dict(normalize=normalize, eps=eps, mode=mode)
          super().__init__(defaults)
 
@@ -107,47 +143,86 @@
          self.set_child('compare', compare)
 
      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):
          main = self.children['main']
          compare = self.children['compare']
 
-         main_vars = main.step(vars.clone(clone_update=True))
-         vars.update_attrs_from_clone_(main_vars)
+         main_var = main.step(var.clone(clone_update=True))
+         var.update_attrs_from_clone_(main_var)
 
-         compare_vars = compare.step(vars.clone(clone_update=True))
-         vars.update_attrs_from_clone_(compare_vars)
+         compare_var = compare.step(var.clone(clone_update=True))
+         var.update_attrs_from_clone_(compare_var)
 
-         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[vars.params[0]])
-         vars.update = cautious_(
-             TensorList(main_vars.get_update()),
-             TensorList(compare_vars.get_update()),
+         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+         var.update = cautious_(
+             TensorList(main_var.get_update()),
+             TensorList(compare_var.get_update()),
              normalize=normalize,
              mode=mode,
              eps=eps,
          )
 
-         return vars
+         return var
 
  class ScaleByGradCosineSimilarity(Transform):
+     """Multiplies the update by its cosine similarity with the gradient.
+     If the cosine similarity is negative, the update is naturally negated as well.
+
+     Args:
+         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+     Examples:
+         Scaled Adam
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 bench.parameters(),
+                 tz.m.Adam(),
+                 tz.m.ScaleByGradCosineSimilarity(),
+                 tz.m.LR(1e-2)
+             )
+     """
      def __init__(
          self,
-         eps=1e-6,
-         target: Target = "update",
+         eps: float = 1e-6,
      ):
          defaults = dict(eps=eps)
-         super().__init__(defaults, uses_grad=True, target=target)
+         super().__init__(defaults, uses_grad=True)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          assert grads is not None
-         eps = self.settings[params[0]]['eps']
+         eps = settings[0]['eps']
          tensors = TensorList(tensors)
          grads = TensorList(grads)
-         cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
+         cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
 
          return tensors.mul_(cos_sim)
 
  class ScaleModulesByCosineSimilarity(Module):
+     """Scales the output of the :code:`main` module by its cosine similarity to the output
+     of the :code:`compare` module.
+
+     Args:
+         main (Chainable): main module or sequence of modules whose update will be scaled.
+         compare (Chainable): module or sequence of modules to compare to.
+         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+     Example:
+         Adam scaled by similarity to RMSprop
+
+         .. code-block:: python
+
+             opt = tz.Modular(
+                 bench.parameters(),
+                 tz.m.ScaleModulesByCosineSimilarity(
+                     main = tz.m.Adam(),
+                     compare = tz.m.RMSprop(0.999, debiased=True),
+                 ),
+                 tz.m.LR(1e-2)
+             )
+     """
      def __init__(
          self,
          main: Chainable,
@@ -161,21 +236,21 @@ class ScaleModulesByCosineSimilarity(Module):
          self.set_child('compare', compare)
 
      @torch.no_grad
-     def step(self, vars):
+     def step(self, var):
          main = self.children['main']
          compare = self.children['compare']
 
-         main_vars = main.step(vars.clone(clone_update=True))
-         vars.update_attrs_from_clone_(main_vars)
+         main_var = main.step(var.clone(clone_update=True))
+         var.update_attrs_from_clone_(main_var)
 
-         compare_vars = compare.step(vars.clone(clone_update=True))
-         vars.update_attrs_from_clone_(compare_vars)
+         compare_var = compare.step(var.clone(clone_update=True))
+         var.update_attrs_from_clone_(compare_var)
 
-         m = TensorList(main_vars.get_update())
-         c = TensorList(compare_vars.get_update())
-         eps = self.settings[vars.params[0]]['eps']
+         m = TensorList(main_var.get_update())
+         c = TensorList(compare_var.get_update())
+         eps = self.settings[var.params[0]]['eps']
 
-         cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
+         cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
-         vars.update = m.mul_(cos_sim)
-         return vars
+         var.update = m.mul_(cos_sim)
+         return var
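
Both cosine-similarity modules above compute the same quantity implied by the code: with u the update, g the comparison direction and eps the division epsilon,

    \operatorname{cos\_sim}(u, g) = \frac{\langle u, g \rangle}{\max\!\left(\lVert u \rVert_2 \, \lVert g \rVert_2,\ \varepsilon\right)}

and the update is then multiplied in place by this scalar.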
torchzero/modules/momentum/ema.py

@@ -5,18 +5,19 @@ from typing import Literal
  import torch
 
  from ...core import Target, Transform
- from ...utils import TensorList, NumberList
+ from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
  from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
 
 
  class EMA(Transform):
-     """Maintains EMA of update.
+     """Maintains an exponential moving average of the update.
 
      Args:
          momentum (float, optional): momentum (beta). Defaults to 0.9.
          dampening (float, optional): momentum dampening. Defaults to 0.
          debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
          lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+         ema_init (str, optional): initial values for the EMA, "zeros" or "update".
          target (Target, optional): target to apply EMA to. Defaults to 'update'.
      """
      def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
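
For reference, the rule these parameters describe is a standard exponential moving average; a sketch based only on the parameter descriptions and the conventional definitions (the exact in-place formulation lives in ema_, imported from ..functional):

    m_t = \beta\, m_{t-1} + (1 - \beta)\, u_t \quad (\text{lerp=True}), \qquad
    m_t = \beta\, m_{t-1} + (1 - \text{dampening})\, u_t \quad (\text{lerp=False}),

with the debiased output \hat m_t = m_t / (1 - \beta^t) when debiased=True.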
@@ -24,13 +25,14 @@ class EMA(Transform):
          super().__init__(defaults, uses_grad=False, target=target)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(self.settings[params[0]])
+         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
 
-         exp_avg = self.get_state('exp_avg', params=params, init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
-         momentum, dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
+         exp_avg = unpack_states(states, tensors, 'exp_avg',
+             init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+         momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
 
          exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
 
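
The other recurring change is that self.get_state / self.get_settings are replaced by unpack_states / unpack_dicts (or direct comprehensions over settings) operating on the per-parameter states and settings lists passed into apply_tensors. A sketch mirroring the call sites in this diff (the MyEMA class itself is hypothetical):

    import torch
    from torchzero.core import Transform
    from torchzero.utils import TensorList, NumberList, unpack_states

    class MyEMA(Transform):
        def __init__(self, momentum: float = 0.9):
            super().__init__(dict(momentum=momentum), uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            # one state dict and one settings dict per parameter
            exp_avg = unpack_states(states, tensors, 'exp_avg', cls=TensorList)
            momentum = NumberList(s['momentum'] for s in settings)
            for t, m, beta in zip(tensors, exp_avg, momentum):
                m.lerp_(t, 1 - beta)   # plain per-tensor EMA update
            return [m.clone() for m in exp_avg]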
@@ -39,44 +41,58 @@
 
 
  class EMASquared(Transform):
+     """Maintains an exponential moving average of squared updates.
+
+     Args:
+         beta (float, optional): momentum value. Defaults to 0.999.
+         amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+         pow (float, optional): power, absolute value is always used. Defaults to 2.
+     """
      EMA_SQ_FN: staticmethod = staticmethod(ema_sq_)
 
-     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2, target: Target = 'update'):
+     def __init__(self, beta:float=0.999, amsgrad=False, pow:float=2):
          defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad)
-         super().__init__(defaults, uses_grad=False, target=target)
+         super().__init__(defaults, uses_grad=False)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-         beta = self.get_settings('beta', params=params, cls=NumberList)
+         beta = NumberList(s['beta'] for s in settings)
 
          if amsgrad:
-             exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+             exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+             exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
          return self.EMA_SQ_FN(TensorList(tensors), exp_avg_sq_=exp_avg_sq, beta=beta, max_exp_avg_sq_=max_exp_avg_sq, pow=pow).clone()
 
  class SqrtEMASquared(Transform):
-     SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+     """Maintains an exponential moving average of squared updates, outputs optionally debiased square root.
 
-     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update',):
+     Args:
+         beta (float, optional): momentum value. Defaults to 0.999.
+         amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+         debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+         pow (float, optional): power, absolute value is always used. Defaults to 2.
+     """
+     SQRT_EMA_SQ_FN: staticmethod = staticmethod(sqrt_ema_sq_)
+     def __init__(self, beta:float=0.999, amsgrad=False, debiased: bool = False, pow:float=2,):
          defaults = dict(beta=beta,pow=pow,amsgrad=amsgrad,debiased=debiased)
-         super().__init__(defaults, uses_grad=False, target=target)
+         super().__init__(defaults, uses_grad=False)
 
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
-         beta = self.get_settings('beta', params=params, cls=NumberList)
+         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+         beta = NumberList(s['beta'] for s in settings)
 
          if amsgrad:
-             exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+             exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg_sq = self.get_state('exp_avg_sq', params=params, cls=TensorList)
+             exp_avg_sq = unpack_states(states, tensors, 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
          return self.SQRT_EMA_SQ_FN(
@@ -91,47 +107,73 @@
 
 
  class Debias(Transform):
+     """Multiplies the update by an Adam debiasing term based on the first and/or second momentum.
+
+     Args:
+         beta1 (float | None, optional):
+             first momentum, should be the same as the first momentum used in preceding modules. Defaults to None.
+         beta2 (float | None, optional):
+             second (squared) momentum, should be the same as the second momentum used in preceding modules. Defaults to None.
+         alpha (float, optional): learning rate. Defaults to 1.
+         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, beta1: float | None = None, beta2: float | None = None, alpha: float = 1, pow:float=2, target: Target = 'update',):
          defaults = dict(beta1=beta1, beta2=beta2, alpha=alpha, pow=pow)
          super().__init__(defaults, uses_grad=False, target=target)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         settings = self.settings[params[0]]
-         pow = settings['pow']
-         alpha, beta1, beta2 = self.get_settings('alpha', 'beta1', 'beta2', params=params, cls=NumberList)
+         pow = settings[0]['pow']
+         alpha, beta1, beta2 = unpack_dicts(settings, 'alpha', 'beta1', 'beta2', cls=NumberList)
 
          return debias(TensorList(tensors), step=step, beta1=beta1, beta2=beta2, alpha=alpha, pow=pow, inplace=True)
 
  class Debias2(Transform):
+     """Multiplies the update by an Adam debiasing term based on the second momentum.
+
+     Args:
+         beta (float | None, optional):
+             second (squared) momentum, should be the same as the second momentum used in preceding modules. Defaults to None.
+         pow (float, optional): power, assumes absolute value is used. Defaults to 2.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, beta: float = 0.999, pow: float = 2, target: Target = 'update',):
          defaults = dict(beta=beta, pow=pow)
          super().__init__(defaults, uses_grad=False, target=target)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         pow = self.settings[params[0]]['pow']
-         beta = self.get_settings('beta', params=params, cls=NumberList)
+         pow = settings[0]['pow']
+         beta = NumberList(s['beta'] for s in settings)
          return debias_second_momentum(TensorList(tensors), step=step, beta=beta, pow=pow, inplace=True)
 
  class CenteredEMASquared(Transform):
-     def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2, target: Target = 'update'):
+     """Maintains a centered exponential moving average of squared updates. This also maintains an additional
+     exponential moving average of un-squared updates, the square of which is subtracted from the squared EMA.
+
+     Args:
+         beta (float, optional): momentum value. Defaults to 0.99.
+         amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+         pow (float, optional): power, absolute value is always used. Defaults to 2.
+     """
+     def __init__(self, beta: float = 0.99, amsgrad=False, pow:float=2):
          defaults = dict(beta=beta, amsgrad=amsgrad, pow=pow)
-         super().__init__(defaults, uses_grad=False, target=target)
+         super().__init__(defaults, uses_grad=False)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
-         beta = self.get_settings('beta', params=params, cls=NumberList)
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
+         beta = NumberList(s['beta'] for s in settings)
 
          if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
          return centered_ema_sq_(
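
The debiasing factor applied by Debias and Debias2 is the conventional Adam bias correction; as a sketch, with t the step count and pow=2,

    \hat u_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}\; u_t \qquad\text{(Debias)}, \qquad
    \hat u_t = \sqrt{1 - \beta_2^{\,t}}\; u_t \qquad\text{(Debias2)},

where a correction term is skipped when the corresponding beta is None.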
@@ -144,21 +186,30 @@
          ).clone()
 
  class CenteredSqrtEMASquared(Transform):
-     def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2, target: Target = 'update'):
+     """Maintains a centered exponential moving average of squared updates, outputs optionally debiased square root.
+     This also maintains an additional exponential moving average of un-squared updates, the square of which is subtracted from the squared EMA.
+
+     Args:
+         beta (float, optional): momentum value. Defaults to 0.99.
+         amsgrad (bool, optional): whether to maintain maximum of the exponential moving average. Defaults to False.
+         debiased (bool, optional): whether to multiply the output by a debiasing term from the Adam method. Defaults to False.
+         pow (float, optional): power, absolute value is always used. Defaults to 2.
+     """
+     def __init__(self, beta: float = 0.99, amsgrad=False, debiased: bool = False, pow:float=2):
          defaults = dict(beta=beta, amsgrad=amsgrad, debiased=debiased, pow=pow)
-         super().__init__(defaults, uses_grad=False, target=target)
+         super().__init__(defaults, uses_grad=False)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
          step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(self.settings[params[0]])
-         beta = self.get_settings('beta', params=params, cls=NumberList)
+         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
+         beta = NumberList(s['beta'] for s in settings)
 
          if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
          return sqrt_centered_ema_sq_(
torchzero/modules/momentum/experimental.py

@@ -6,7 +6,7 @@ from typing import Literal
  import torch
 
  from ...core import Target, Transform
- from ...utils import NumberList, TensorList
+ from ...utils import NumberList, TensorList, unpack_states, unpack_dicts
  from ..functional import ema_, ema_sq_, sqrt_ema_sq_
  from .ema import EMASquared, SqrtEMASquared
  from .momentum import nag_
@@ -43,22 +43,22 @@ def precentered_ema_sq_(
      return exp_avg_sq_
 
  class PrecenteredEMASquared(Transform):
+     """Maintains an un-squared EMA; updates are centered by it before being fed into the squared EMA."""
      def __init__(self, beta1:float=0.99, beta2=0.99, min_step: int = 2, amsgrad=False, pow:float=2, target: Target = 'update'):
          defaults = dict(beta1=beta1,beta2=beta2,pow=pow,amsgrad=amsgrad, min_step=min_step)
          super().__init__(defaults, uses_grad=False, target=target)
-         self.current_step = 0
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         self.current_step += 1
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
-         beta1, beta2 = self.get_settings('beta1','beta2', params=params, cls=NumberList)
-         amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(self.settings[params[0]])
+         beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
+         amsgrad, pow, min_step = itemgetter('amsgrad', 'pow', 'min_step')(settings[0])
 
          if amsgrad:
-             exp_avg, exp_avg_sq, max_exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq, max_exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq', cls=TensorList)
          else:
-             exp_avg, exp_avg_sq = self.get_state('exp_avg', 'exp_avg_sq', params=params, cls=TensorList)
+             exp_avg, exp_avg_sq = unpack_states(states, tensors, 'exp_avg', 'exp_avg_sq', cls=TensorList)
              max_exp_avg_sq = None
 
          return precentered_ema_sq_(
@@ -67,7 +67,7 @@ class PrecenteredEMASquared(Transform):
              exp_avg_sq_=exp_avg_sq,
              beta1=beta1,
              beta2=beta2,
-             step = self.current_step,
+             step = step,
              min_step=min_step,
              pow=pow,
              max_exp_avg_sq_=max_exp_avg_sq,
@@ -119,9 +119,11 @@
          pow=pow,debiased=debiased,step=step,ema_sq_fn=partial(nag_ema_sq_,lerp=lerp))
 
  class NesterovEMASquared(EMASquared):
+     """Squared momentum with the Nesterov momentum rule."""
      EMA_SQ_FN = staticmethod(nag_ema_sq_)
 
  class SqrtNesterovEMASquared(SqrtEMASquared):
+     """Square root of squared momentum with the Nesterov momentum rule."""
      SQRT_EMA_SQ_FN = staticmethod(sqrt_nag_ema_sq_)
 
 
@@ -141,14 +143,20 @@ def coordinate_momentum_(
 
 
  class CoordinateMomentum(Transform):
+     """Maintains a momentum buffer; on each step, each value in the buffer has a :code:`p` chance to be updated with the new value.
+
+     Args:
+         p (float, optional): probability for each buffer value to be replaced with the new value. Defaults to 0.1.
+         target (Target, optional): target. Defaults to 'update'.
+     """
      def __init__(self, p: float = 0.1, target: Target = 'update'):
          defaults = dict(p=p)
          super().__init__(defaults, uses_grad=False, target=target)
 
      @torch.no_grad
-     def transform(self, tensors, params, grads, vars):
-         p = self.get_settings('p', params=params, cls=NumberList)
-         velocity = self.get_state('velocity', params=params, cls=TensorList)
+     def apply_tensors(self, tensors, params, grads, loss, states, settings):
+         p = NumberList(s['p'] for s in settings)
+         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
          return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
 
 
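
For intuition, the coordinate-wise rule in the CoordinateMomentum docstring amounts to a random per-element overwrite of the buffer. A rough standalone sketch in plain PyTorch (single tensor; the in-package coordinate_momentum_ works on TensorLists):

    import torch

    def coordinate_momentum_sketch(update: torch.Tensor, velocity: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        # each buffer entry has probability p of being replaced by the new update value
        replace = torch.rand_like(update) < p
        velocity[replace] = update[replace]
        return velocity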
@@ -180,7 +188,7 @@ class CoordinateMomentum(Transform):
  #         super().__init__(defaults, uses_grad=False)
 
  #     @torch.no_grad
- #     def transform(self, tensors, params, grads, vars):
+ #     def apply(self, tensors, params, grads, loss, states, settings):
  #         momentum,dampening = self.get_settings('momentum','dampening', params=params, cls=NumberList)
  #         abs,lerp,normalize_velocity = self.first_setting('abs','lerp','normalize_velocity', params=params)
  #         velocity = self.get_state('velocity', params=params, cls=TensorList)