torchzero 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. docs/source/conf.py +6 -4
  2. docs/source/docstring template.py +46 -0
  3. tests/test_identical.py +2 -3
  4. tests/test_opts.py +64 -50
  5. tests/test_vars.py +1 -0
  6. torchzero/core/module.py +138 -6
  7. torchzero/core/transform.py +158 -51
  8. torchzero/modules/__init__.py +3 -2
  9. torchzero/modules/clipping/clipping.py +114 -17
  10. torchzero/modules/clipping/ema_clipping.py +27 -13
  11. torchzero/modules/clipping/growth_clipping.py +8 -7
  12. torchzero/modules/experimental/__init__.py +22 -5
  13. torchzero/modules/experimental/absoap.py +5 -2
  14. torchzero/modules/experimental/adadam.py +8 -2
  15. torchzero/modules/experimental/adamY.py +8 -2
  16. torchzero/modules/experimental/adam_lambertw.py +149 -0
  17. torchzero/modules/{line_search/trust_region.py → experimental/adaptive_step_size.py} +21 -4
  18. torchzero/modules/experimental/adasoap.py +7 -2
  19. torchzero/modules/experimental/cosine.py +214 -0
  20. torchzero/modules/experimental/cubic_adam.py +97 -0
  21. torchzero/modules/{projections → experimental}/dct.py +11 -11
  22. torchzero/modules/experimental/eigendescent.py +4 -1
  23. torchzero/modules/experimental/etf.py +32 -9
  24. torchzero/modules/experimental/exp_adam.py +113 -0
  25. torchzero/modules/experimental/expanded_lbfgs.py +141 -0
  26. torchzero/modules/{projections → experimental}/fft.py +10 -10
  27. torchzero/modules/experimental/hnewton.py +85 -0
  28. torchzero/modules/{quasi_newton/experimental → experimental}/modular_lbfgs.py +27 -28
  29. torchzero/modules/experimental/newtonnewton.py +7 -3
  30. torchzero/modules/experimental/parabolic_search.py +220 -0
  31. torchzero/modules/experimental/reduce_outward_lr.py +4 -4
  32. torchzero/modules/{projections/structural.py → experimental/structural_projections.py} +12 -54
  33. torchzero/modules/experimental/subspace_preconditioners.py +11 -4
  34. torchzero/modules/experimental/{tada.py → tensor_adagrad.py} +10 -6
  35. torchzero/modules/functional.py +12 -2
  36. torchzero/modules/grad_approximation/fdm.py +30 -3
  37. torchzero/modules/grad_approximation/forward_gradient.py +13 -3
  38. torchzero/modules/grad_approximation/grad_approximator.py +51 -6
  39. torchzero/modules/grad_approximation/rfdm.py +285 -38
  40. torchzero/modules/higher_order/higher_order_newton.py +152 -89
  41. torchzero/modules/line_search/__init__.py +4 -4
  42. torchzero/modules/line_search/adaptive.py +99 -0
  43. torchzero/modules/line_search/backtracking.py +34 -9
  44. torchzero/modules/line_search/line_search.py +70 -12
  45. torchzero/modules/line_search/polynomial.py +233 -0
  46. torchzero/modules/line_search/scipy.py +2 -2
  47. torchzero/modules/line_search/strong_wolfe.py +34 -7
  48. torchzero/modules/misc/__init__.py +27 -0
  49. torchzero/modules/{ops → misc}/debug.py +24 -1
  50. torchzero/modules/misc/escape.py +60 -0
  51. torchzero/modules/misc/gradient_accumulation.py +70 -0
  52. torchzero/modules/misc/misc.py +316 -0
  53. torchzero/modules/misc/multistep.py +158 -0
  54. torchzero/modules/misc/regularization.py +171 -0
  55. torchzero/modules/{ops → misc}/split.py +29 -1
  56. torchzero/modules/{ops → misc}/switch.py +44 -3
  57. torchzero/modules/momentum/__init__.py +1 -1
  58. torchzero/modules/momentum/averaging.py +6 -6
  59. torchzero/modules/momentum/cautious.py +45 -8
  60. torchzero/modules/momentum/ema.py +7 -7
  61. torchzero/modules/momentum/experimental.py +2 -2
  62. torchzero/modules/momentum/matrix_momentum.py +90 -63
  63. torchzero/modules/momentum/momentum.py +2 -1
  64. torchzero/modules/ops/__init__.py +3 -31
  65. torchzero/modules/ops/accumulate.py +6 -10
  66. torchzero/modules/ops/binary.py +72 -26
  67. torchzero/modules/ops/multi.py +77 -16
  68. torchzero/modules/ops/reduce.py +15 -7
  69. torchzero/modules/ops/unary.py +29 -13
  70. torchzero/modules/ops/utility.py +20 -12
  71. torchzero/modules/optimizers/__init__.py +12 -3
  72. torchzero/modules/optimizers/adagrad.py +23 -13
  73. torchzero/modules/optimizers/adahessian.py +223 -0
  74. torchzero/modules/optimizers/adam.py +7 -6
  75. torchzero/modules/optimizers/adan.py +110 -0
  76. torchzero/modules/optimizers/adaptive_heavyball.py +57 -0
  77. torchzero/modules/optimizers/esgd.py +171 -0
  78. torchzero/modules/{experimental/spectral.py → optimizers/ladagrad.py} +91 -71
  79. torchzero/modules/optimizers/lion.py +1 -1
  80. torchzero/modules/optimizers/mars.py +91 -0
  81. torchzero/modules/optimizers/msam.py +186 -0
  82. torchzero/modules/optimizers/muon.py +30 -5
  83. torchzero/modules/optimizers/orthograd.py +1 -1
  84. torchzero/modules/optimizers/rmsprop.py +7 -4
  85. torchzero/modules/optimizers/rprop.py +42 -8
  86. torchzero/modules/optimizers/sam.py +163 -0
  87. torchzero/modules/optimizers/shampoo.py +39 -5
  88. torchzero/modules/optimizers/soap.py +29 -19
  89. torchzero/modules/optimizers/sophia_h.py +71 -14
  90. torchzero/modules/projections/__init__.py +2 -4
  91. torchzero/modules/projections/cast.py +51 -0
  92. torchzero/modules/projections/galore.py +3 -1
  93. torchzero/modules/projections/projection.py +188 -94
  94. torchzero/modules/quasi_newton/__init__.py +12 -2
  95. torchzero/modules/quasi_newton/cg.py +160 -59
  96. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +163 -0
  97. torchzero/modules/quasi_newton/lbfgs.py +154 -97
  98. torchzero/modules/quasi_newton/lsr1.py +101 -57
  99. torchzero/modules/quasi_newton/quasi_newton.py +863 -215
  100. torchzero/modules/quasi_newton/trust_region.py +397 -0
  101. torchzero/modules/second_order/__init__.py +2 -2
  102. torchzero/modules/second_order/newton.py +220 -41
  103. torchzero/modules/second_order/newton_cg.py +300 -11
  104. torchzero/modules/second_order/nystrom.py +104 -1
  105. torchzero/modules/smoothing/gaussian.py +34 -0
  106. torchzero/modules/smoothing/laplacian.py +14 -4
  107. torchzero/modules/step_size/__init__.py +2 -0
  108. torchzero/modules/step_size/adaptive.py +122 -0
  109. torchzero/modules/step_size/lr.py +154 -0
  110. torchzero/modules/weight_decay/__init__.py +1 -1
  111. torchzero/modules/weight_decay/weight_decay.py +89 -7
  112. torchzero/modules/wrappers/optim_wrapper.py +29 -1
  113. torchzero/optim/wrappers/directsearch.py +39 -2
  114. torchzero/optim/wrappers/fcmaes.py +21 -13
  115. torchzero/optim/wrappers/mads.py +5 -6
  116. torchzero/optim/wrappers/nevergrad.py +16 -1
  117. torchzero/optim/wrappers/optuna.py +1 -1
  118. torchzero/optim/wrappers/scipy.py +5 -3
  119. torchzero/utils/__init__.py +2 -2
  120. torchzero/utils/derivatives.py +3 -3
  121. torchzero/utils/linalg/__init__.py +1 -1
  122. torchzero/utils/linalg/solve.py +251 -12
  123. torchzero/utils/numberlist.py +2 -0
  124. torchzero/utils/python_tools.py +10 -0
  125. torchzero/utils/tensorlist.py +40 -28
  126. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/METADATA +65 -40
  127. torchzero-0.3.11.dist-info/RECORD +159 -0
  128. torchzero/modules/experimental/diagonal_higher_order_newton.py +0 -225
  129. torchzero/modules/experimental/soapy.py +0 -163
  130. torchzero/modules/experimental/structured_newton.py +0 -111
  131. torchzero/modules/lr/__init__.py +0 -2
  132. torchzero/modules/lr/adaptive.py +0 -93
  133. torchzero/modules/lr/lr.py +0 -63
  134. torchzero/modules/ops/misc.py +0 -418
  135. torchzero/modules/quasi_newton/experimental/__init__.py +0 -1
  136. torchzero/modules/quasi_newton/olbfgs.py +0 -196
  137. torchzero-0.3.10.dist-info/RECORD +0 -139
  138. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/WHEEL +0 -0
  139. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/licenses/LICENSE +0 -0
  140. {torchzero-0.3.10.dist-info → torchzero-0.3.11.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,28 @@ from ...core import Chainable, Module
 
 
 class Alternate(Module):
-    """alternate between stepping with `modules`"""
+    """Alternates between stepping with :code:`modules`.
+
+    That is, first step is performed with 1st module, second step with second module, etc.
+
+    Args:
+        steps (int | Iterable[int], optional): number of steps to perform with each module. Defaults to 1.
+
+    Examples:
+        Alternate between Adam, SignSGD and RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Alternate(
+                    tz.m.Adam(),
+                    [tz.m.SignSGD(), tz.m.Mul(0.5)],
+                    tz.m.RMSprop(),
+                ),
+                tz.m.LR(1e-3),
+            )
+    """
     LOOP = True
     def __init__(self, *modules: Chainable, steps: int | Iterable[int] = 1):
         if isinstance(steps, Iterable):
@@ -54,14 +75,34 @@ class Alternate(Module):
         return var
 
 class Switch(Alternate):
-    """switch to next module after some steps"""
+    """After :code:`steps` steps switches to the next module.
+
+    Args:
+        steps (int | Iterable[int]): Number of steps to perform with each module.
+
+    Examples:
+        Start with Adam, switch to L-BFGS after 1000th step and Truncated Newton on 2000th step.
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                model.parameters(),
+                tz.m.Switch(
+                    [tz.m.Adam(), tz.m.LR(1e-3)],
+                    [tz.m.LBFGS(), tz.m.Backtracking()],
+                    [tz.m.NewtonCG(maxiter=20), tz.m.Backtracking()],
+                    steps = (1000, 2000)
+                )
+            )
+    """
+
     LOOP = False
     def __init__(self, *modules: Chainable, steps: int | Iterable[int]):
 
         if isinstance(steps, Iterable):
             steps = list(steps)
             if len(steps) != len(modules) - 1:
-                raise ValueError(f"steps must be the same length as modules, got {len(modules) = }, {len(steps) = }")
+                raise ValueError(f"steps must be the same length as modules minus 1, got {len(modules) = }, {len(steps) = }")
 
             steps.append(1)
 
@@ -11,4 +11,4 @@ from .experimental import CoordinateMomentum
 # from .matrix_momentum import MatrixMomentum
 
 from .momentum import NAG, HeavyBall
-from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
@@ -21,8 +21,8 @@ class Averaging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        history_size = settings['history_size']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
             state['average'] = torch.zeros_like(tensor)
@@ -46,8 +46,8 @@ class WeightedAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        weights = settings['weights']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        weights = setting['weights']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=len(weights))
@@ -80,8 +80,8 @@ class MedianAveraging(TensorwiseTransform):
         super().__init__(uses_grad=False, defaults=defaults, target=target)
 
     @torch.no_grad
-    def apply_tensor(self, tensor, param, grad, loss, state, settings):
-        history_size = settings['history_size']
+    def apply_tensor(self, tensor, param, grad, loss, state, setting):
+        history_size = setting['history_size']
 
         if 'history' not in state:
             state['history'] = deque(maxlen=history_size)
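In the TensorwiseTransform hunks above, the per-tensor hook's last argument is renamed from ``settings`` to ``setting``: it receives the settings dict for a single parameter. A minimal sketch of a subclass written against the renamed hook; the import path and exact base-class signature are assumptions inferred from this diff, not documented API:

.. code-block:: python

    import torch
    from torchzero.core import TensorwiseTransform  # location assumed from the relative imports above

    class ClampUpdate(TensorwiseTransform):
        # hypothetical per-tensor transform: clamps each update to [-value, value]
        def __init__(self, value: float = 1.0):
            super().__init__(uses_grad=False, defaults=dict(value=value))

        @torch.no_grad
        def apply_tensor(self, tensor, param, grad, loss, state, setting):
            # ``setting`` is the per-parameter settings dict (singular after this release)
            return tensor.clamp(-setting['value'], setting['value'])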
@@ -55,9 +55,20 @@ class Cautious(Transform):
 
        "backtrack" - negate them (same as using update magnitude and gradient sign)
 
-    reference
-        *Cautious Optimizers: Improving Training with One Line of Code.
-        Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu*
+    Examples:
+        Cautious Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.Cautious(),
+                tz.m.LR(1e-2)
+            )
+
+    References:
+        Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
     """
 
     def __init__(
@@ -70,7 +81,7 @@ class Cautious(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(settings[0])
         return cautious_(TensorList(tensors), TensorList(grads), normalize=normalize, eps=eps, mode=mode)
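The ``cautious_`` call above implements the "one line of code" idea from the paper cited in the docstring: mask out update components whose sign disagrees with the gradient, then rescale the survivors. A self-contained sketch of that masking idea in plain PyTorch (not torchzero's ``cautious_``; the ``normalize`` and ``mode`` options are omitted):

.. code-block:: python

    import torch

    def cautious_mask(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # keep only components where the update and the gradient agree in sign
        mask = (update * grad > 0).to(update.dtype)
        # rescale so the surviving components keep the original average magnitude
        return update * mask / mask.mean().clamp(min=eps)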
@@ -89,7 +100,7 @@ class UpdateGradientSignConsistency(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         normalize, eps = itemgetter('normalize', 'eps')(settings[0])
 
@@ -159,6 +170,18 @@ class ScaleByGradCosineSimilarity(Transform):
 
     Args:
         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Examples:
+        Scaled Adam
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.Adam(),
+                tz.m.ScaleByGradCosineSimilarity(),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -168,12 +191,12 @@ class ScaleByGradCosineSimilarity(Transform):
         super().__init__(defaults, uses_grad=True)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         assert grads is not None
         eps = settings[0]['eps']
         tensors = TensorList(tensors)
         grads = TensorList(grads)
-        cos_sim = (tensors.dot(grads)) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
+        cos_sim = tensors.dot(grads) / (tensors.global_vector_norm() * grads.global_vector_norm()).clip(min=eps)
 
         return tensors.mul_(cos_sim)
 
@@ -185,6 +208,20 @@ class ScaleModulesByCosineSimilarity(Module):
         main (Chainable): main module or sequence of modules whose update will be scaled.
         compare (Chainable): module or sequence of modules to compare to
         eps (float, optional): epsilon for division. Defaults to 1e-6.
+
+    Example:
+        Adam scaled by similarity to RMSprop
+
+        .. code-block:: python
+
+            opt = tz.Modular(
+                bench.parameters(),
+                tz.m.ScaleModulesByCosineSimilarity(
+                    main = tz.m.Adam(),
+                    compare = tz.m.RMSprop(0.999, debiased=True),
+                ),
+                tz.m.LR(1e-2)
+            )
     """
     def __init__(
         self,
@@ -213,7 +250,7 @@ class ScaleModulesByCosineSimilarity(Module):
         c = TensorList(compare_var.get_update())
         eps = self.settings[var.params[0]]['eps']
 
-        cos_sim = (m.dot(c)) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
+        cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)
 
         var.update = m.mul_(cos_sim)
         return var
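The parenthesis cleanups above do not change the math: the scaling factor is the cosine similarity ⟨u, g⟩ / (‖u‖ ‖g‖), with the denominator clipped away from zero. A flat-tensor sketch, independent of TensorList:

.. code-block:: python

    import torch

    def scale_by_cos_sim(update: torch.Tensor, grad: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        denom = (update.norm() * grad.norm()).clamp(min=eps)
        cos_sim = update.flatten().dot(grad.flatten()) / denom
        return update * cos_sim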
@@ -25,7 +25,7 @@ class EMA(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
@@ -55,7 +55,7 @@ class EMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(self.settings[params[0]])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -83,7 +83,7 @@ class SqrtEMASquared(Transform):
 
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -123,7 +123,7 @@ class Debias(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -145,7 +145,7 @@ class Debias2(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         pow = settings[0]['pow']
@@ -166,7 +166,7 @@ class CenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         amsgrad, pow = itemgetter('amsgrad', 'pow')(settings[0])
         beta = NumberList(s['beta'] for s in settings)
 
@@ -200,7 +200,7 @@ class CenteredSqrtEMASquared(Transform):
         super().__init__(defaults, uses_grad=False)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         amsgrad, pow, debiased = itemgetter('amsgrad', 'pow', 'debiased')(settings[0])
@@ -49,7 +49,7 @@ class PrecenteredEMASquared(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
 
         beta1, beta2 = unpack_dicts(settings, 'beta1','beta2', cls=NumberList)
@@ -154,7 +154,7 @@ class CoordinateMomentum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         p = NumberList(s['p'] for s in settings)
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         return coordinate_momentum_(TensorList(tensors), velocity_=velocity, p=p).clone()
@@ -7,18 +7,39 @@ from ...utils import NumberList, TensorList, as_tensorlist
 from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
 
 class MatrixMomentum(Module):
-    """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, and for some reason it also has very good generalization on elastic net logistic regression.
+
+    .. note::
+        :code:`mu` needs to be tuned very carefully. It is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+
+    .. note::
+        I have devised an adaptive version of this - :code:`tz.m.AdaptiveMatrixMomentum`, and it works well
+        without having to tune :code:`mu`.
 
-    `mu` is supposed to be smaller than (1/largest eigenvalue), otherwise this will be very unstable.
+    .. note::
+        In most cases MatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
     Args:
         mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
         hvp_method (str, optional):
-            How to calculate hessian-vector products.
-            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
@@ -30,7 +51,7 @@ class MatrixMomentum(Module):
         self,
         mu=0.1,
         beta: float = 1,
-        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
         h: float = 1e-3,
         hvp_tfm: Chainable | None = None,
     ):
@@ -40,57 +61,66 @@ class MatrixMomentum(Module):
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_update')
+
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         assert var.closure is not None
-        prev_update = self.get_state(var.params, 'prev_update', cls=TensorList)
+        prev_update = self.get_state(var.params, 'prev_update')
         hvp_method = self.settings[var.params[0]]['hvp_method']
         h = self.settings[var.params[0]]['h']
 
-        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
-        elif hvp_method == 'central':
-            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
+        if 'hvp_tfm' in self.children:
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
-        else:
-            raise ValueError(hvp_method)
+        self.store(var.params, "Hvp", Hvp)
 
-        if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
 
+    @torch.no_grad
+    def apply(self, var):
         update = TensorList(var.get_update())
+        Hvp, prev_update = self.get_state(var.params, 'Hvp', 'prev_update', cls=TensorList)
+        mu,beta = self.get_settings(var.params, 'mu','beta', cls=NumberList)
 
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*mu)
+        update.add_(prev_update - Hvp*mu)
         prev_update.set_(update * beta)
         var.update = update
         return var
 
 
 class AdaptiveMatrixMomentum(Module):
-    """
-    May be useful for ill conditioned stochastic quadratic objectives but I need to test this.
-    Evaluates hessian vector product on each step (via finite difference or autograd).
+    """Second order momentum method.
+
+    Matrix momentum is useful for convex objectives, and for some reason it also has very good generalization on elastic net logistic regression.
+
+    .. note::
+        In most cases AdaptiveMatrixMomentum should be the first module in the chain because it relies on autograd.
+
+    .. note::
+        This module requires a closure passed to the optimizer step,
+        as it needs to re-evaluate the loss and gradients for calculating HVPs.
+        The closure must accept a ``backward`` argument (refer to documentation).
 
-    This version estimates mu via a simple heuristic: ||s||/||y||, where s is parameter difference, y is gradient difference.
 
     Args:
         mu_mul (float, optional): multiplier to the estimated mu. Defaults to 1.
         beta (float, optional): decay for the buffer, this is not part of the original update rule. Defaults to 1.
         hvp_method (str, optional):
-            How to calculate hessian-vector products.
-            Exact - "autograd", or finite difference - "forward", "central". Defaults to 'forward'.
+            Determines how Hessian-vector products are evaluated.
+
+            - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
+              This requires creating a graph for the gradient.
+            - ``"forward"``: Use a forward finite difference formula to
+              approximate the HVP. This requires one extra gradient evaluation.
+            - ``"central"``: Use a central finite difference formula for a
+              more accurate HVP approximation. This requires two extra
+              gradient evaluations.
+            Defaults to "autograd".
         h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
         hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
 
@@ -103,7 +133,7 @@ class AdaptiveMatrixMomentum(Module):
         mu_mul: float = 1,
         beta: float = 1,
         eps=1e-4,
-        hvp_method: Literal["autograd", "forward", "central"] = "forward",
+        hvp_method: Literal["autograd", "forward", "central"] = "autograd",
         h: float = 1e-3,
         hvp_tfm: Chainable | None = None,
     ):
@@ -113,8 +143,12 @@ class AdaptiveMatrixMomentum(Module):
         if hvp_tfm is not None:
             self.set_child('hvp_tfm', hvp_tfm)
 
+    def reset_for_online(self):
+        super().reset_for_online()
+        self.clear_state_keys('prev_params', 'prev_grad')
+
     @torch.no_grad
-    def step(self, var):
+    def update(self, var):
         assert var.closure is not None
         prev_update, prev_params, prev_grad = self.get_state(var.params, 'prev_update', 'prev_params', 'prev_grad', cls=TensorList)
 
@@ -123,43 +157,36 @@ class AdaptiveMatrixMomentum(Module):
         h = settings['h']
         eps = settings['eps']
 
-        mu_mul, beta = self.get_settings(var.params, 'mu_mul','beta', cls=NumberList)
-
-        if hvp_method == 'autograd':
-            with torch.enable_grad():
-                grad = var.get_grad(create_graph=True)
-                hvp_ = TensorList(hvp(var.params, grads=grad, vec=prev_update, allow_unused=True, retain_graph=False)).detach_()
-
-        elif hvp_method == 'forward':
-            var.get_grad()
-            l, hvp_ = hvp_fd_forward(var.closure, var.params, vec=prev_update, g_0=var.grad, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
-
-        elif hvp_method == 'central':
-            l, hvp_ = hvp_fd_central(var.closure, var.params, vec=prev_update, h=h, normalize=True)
-            if var.loss_approx is None: var.loss_approx = l
+        mu_mul = NumberList(self.settings[p]['mu_mul'] for p in var.params)
 
-        else:
-            raise ValueError(hvp_method)
+        Hvp, _ = self.Hvp(prev_update, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
+        Hvp = [t.detach() for t in Hvp]
 
         if 'hvp_tfm' in self.children:
-            hvp_ = TensorList(apply_transform(self.children['hvp_tfm'], hvp_, params=var.params, grads=var.grad, var=var))
+            Hvp = TensorList(apply_transform(self.children['hvp_tfm'], Hvp, params=var.params, grads=var.grad, var=var))
 
         # adaptive part
-        update = TensorList(var.get_update())
-
         s_k = var.params - prev_params
         prev_params.copy_(var.params)
 
-        assert var.grad is not None
-        y_k = var.grad - prev_grad
-        prev_grad.copy_(var.grad)
+        if hvp_method != 'central': assert var.grad is not None
+        grad = var.get_grad()
+        y_k = grad - prev_grad
+        prev_grad.copy_(grad)
 
         ada_mu = (s_k.global_vector_norm() / (y_k.global_vector_norm() + eps)) * mu_mul
 
-        # matrix momentum uppdate
-        hvp_ = as_tensorlist(hvp_)
-        update.add_(prev_update - hvp_*ada_mu)
+        self.store(var.params, ['Hvp', 'ada_mu'], [Hvp, ada_mu])
+
+    @torch.no_grad
+    def apply(self, var):
+        Hvp, ada_mu = self.get_state(var.params, 'Hvp', 'ada_mu')
+        Hvp = as_tensorlist(Hvp)
+        beta = NumberList(self.settings[p]['beta'] for p in var.params)
+        update = TensorList(var.get_update())
+        prev_update = TensorList(self.state[p]['prev_update'] for p in var.params)
+
+        update.add_(prev_update - Hvp*ada_mu)
         prev_update.set_(update * beta)
         var.update = update
         return var
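For context on the ``hvp_method`` options and the update rule in the hunks above: matrix momentum adds the previous update minus ``mu`` times a Hessian-vector product with it, and the finite-difference modes approximate that product from extra gradient evaluations. A standalone sketch on a single flat parameter tensor (plain PyTorch; the helper names are invented for illustration, and the ``normalize=True`` handling in the diff is omitted):

.. code-block:: python

    import torch

    def forward_diff_hvp(grad_fn, x: torch.Tensor, v: torch.Tensor, h: float = 1e-3) -> torch.Tensor:
        # H @ v ~= (grad(x + h*v) - grad(x)) / h  -- one extra gradient evaluation
        return (grad_fn(x + h * v) - grad_fn(x)) / h

    def matrix_momentum_step(grad_fn, x: torch.Tensor, prev_update: torch.Tensor, mu: float = 0.1) -> torch.Tensor:
        # u_t = g_t + u_{t-1} - mu * (H @ u_{t-1})
        Hv = forward_diff_hvp(grad_fn, x, prev_update)
        return grad_fn(x) + prev_update - mu * Hv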
@@ -55,9 +55,10 @@ class NAG(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         velocity = unpack_states(states, tensors, 'velocity', cls=TensorList)
         lerp = self.settings[params[0]]['lerp']
 
         momentum,dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
         return nag_(TensorList(tensors), velocity_=velocity,momentum=momentum,dampening=dampening,lerp=lerp)
+
@@ -7,7 +7,7 @@ from .accumulate import (
 )
 from .binary import (
     Add,
-    BinaryOperation,
+    BinaryOperationBase,
     Clip,
     CopyMagnitude,
     CopySign,
@@ -27,37 +27,12 @@ from .binary import (
     Sub,
     Threshold,
 )
-from .debug import PrintShape, PrintUpdate
-from .misc import (
-    DivByLoss,
-    Dropout,
-    FillLoss,
-    GradientAccumulation,
-    GradSign,
-    GraftGradToUpdate,
-    GraftToGrad,
-    GraftToParams,
-    LastAbsoluteRatio,
-    LastDifference,
-    LastGradDifference,
-    LastProduct,
-    LastRatio,
-    MulByLoss,
-    Multistep,
-    NegateOnLossIncrease,
-    NoiseSign,
-    Previous,
-    Relative,
-    Sequential,
-    UpdateSign,
-    WeightDropout,
-)
 from .multi import (
     ClipModules,
     DivModules,
     GraftModules,
     LerpModules,
-    MultiOperation,
+    MultiOperationBase,
     PowModules,
     SubModules,
 )
@@ -66,13 +41,11 @@ from .reduce import (
     Mean,
     MinimumModules,
     Prod,
-    ReduceOperation,
+    ReduceOperationBase,
     Sum,
     WeightedMean,
     WeightedSum,
 )
-from .split import Split
-from .switch import Alternate, Switch
 from .unary import (
     Abs,
     CustomUnaryOperation,
@@ -97,7 +70,6 @@ from .utility import (
     Randn,
     RandomSample,
     Uniform,
-    Update,
     UpdateToNone,
     Zeros,
 )
@@ -1,11 +1,7 @@
-from collections import deque
-from operator import itemgetter
-from typing import Literal
-
 import torch
 
 from ...core import Target, Transform
-from ...utils import TensorList, NumberList, unpack_states, unpack_dicts
+from ...utils import TensorList, unpack_states
 
 class AccumulateSum(Transform):
     """Accumulates sum of all past updates.
@@ -19,7 +15,7 @@ class AccumulateSum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         sum = unpack_states(states, tensors, 'sum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return sum.add_(tensors).lazy_mul(decay, clone=True)
@@ -36,7 +32,7 @@ class AccumulateMean(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         step = self.global_state['step'] = self.global_state.get('step', 0) + 1
         mean = unpack_states(states, tensors, 'mean', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
@@ -54,7 +50,7 @@ class AccumulateProduct(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         prod = unpack_states(states, tensors, 'prod', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return prod.mul_(tensors).lazy_mul(decay, clone=True)
@@ -71,7 +67,7 @@ class AccumulateMaximum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         maximum = unpack_states(states, tensors, 'maximum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return maximum.maximum_(tensors).lazy_mul(decay, clone=True)
@@ -88,7 +84,7 @@ class AccumulateMinimum(Transform):
         super().__init__(defaults, uses_grad=False, target=target)
 
     @torch.no_grad
-    def apply(self, tensors, params, grads, loss, states, settings):
+    def apply_tensors(self, tensors, params, grads, loss, states, settings):
         minimum = unpack_states(states, tensors, 'minimum', cls=TensorList)
         decay = [1-s['decay'] for s in settings]
         return minimum.minimum_(tensors).lazy_mul(decay, clone=True)
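Most of the renames in this diff follow the same pattern: the batched ``Transform`` hook goes from ``apply(...)`` to ``apply_tensors(...)`` with an unchanged parameter list. A minimal sketch of a custom transform written against the new hook name; the base-class constructor signature and import location are inferred from the hunks above and should be treated as assumptions:

.. code-block:: python

    import torch
    from torchzero.core import Transform  # location inferred from the relative imports in the hunks

    class GlobalNormalize(Transform):
        # hypothetical transform: divides the whole update by its global L2 norm
        def __init__(self, eps: float = 1e-8):
            super().__init__(dict(eps=eps), uses_grad=False)

        @torch.no_grad
        def apply_tensors(self, tensors, params, grads, loss, states, settings):
            eps = settings[0]['eps']
            norm = torch.sqrt(sum(t.pow(2).sum() for t in tensors)).clamp(min=eps)
            return [t / norm for t in tensors]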