torchzero 0.3.11__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. tests/test_opts.py +95 -76
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +229 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/spsa1.py +93 -0
  42. torchzero/modules/experimental/structural_projections.py +1 -1
  43. torchzero/modules/functional.py +50 -14
  44. torchzero/modules/grad_approximation/__init__.py +1 -1
  45. torchzero/modules/grad_approximation/fdm.py +19 -20
  46. torchzero/modules/grad_approximation/forward_gradient.py +6 -7
  47. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  48. torchzero/modules/grad_approximation/rfdm.py +114 -175
  49. torchzero/modules/higher_order/__init__.py +1 -1
  50. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  51. torchzero/modules/least_squares/__init__.py +1 -0
  52. torchzero/modules/least_squares/gn.py +161 -0
  53. torchzero/modules/line_search/__init__.py +2 -2
  54. torchzero/modules/line_search/_polyinterp.py +289 -0
  55. torchzero/modules/line_search/adaptive.py +69 -44
  56. torchzero/modules/line_search/backtracking.py +83 -70
  57. torchzero/modules/line_search/line_search.py +159 -68
  58. torchzero/modules/line_search/scipy.py +16 -4
  59. torchzero/modules/line_search/strong_wolfe.py +319 -220
  60. torchzero/modules/misc/__init__.py +8 -0
  61. torchzero/modules/misc/debug.py +4 -4
  62. torchzero/modules/misc/escape.py +9 -7
  63. torchzero/modules/misc/gradient_accumulation.py +88 -22
  64. torchzero/modules/misc/homotopy.py +59 -0
  65. torchzero/modules/misc/misc.py +82 -15
  66. torchzero/modules/misc/multistep.py +47 -11
  67. torchzero/modules/misc/regularization.py +5 -9
  68. torchzero/modules/misc/split.py +55 -35
  69. torchzero/modules/misc/switch.py +1 -1
  70. torchzero/modules/momentum/__init__.py +1 -5
  71. torchzero/modules/momentum/averaging.py +3 -3
  72. torchzero/modules/momentum/cautious.py +42 -47
  73. torchzero/modules/momentum/momentum.py +35 -1
  74. torchzero/modules/ops/__init__.py +9 -1
  75. torchzero/modules/ops/binary.py +9 -8
  76. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  77. torchzero/modules/ops/multi.py +15 -15
  78. torchzero/modules/ops/reduce.py +1 -1
  79. torchzero/modules/ops/utility.py +12 -8
  80. torchzero/modules/projections/projection.py +4 -4
  81. torchzero/modules/quasi_newton/__init__.py +1 -16
  82. torchzero/modules/quasi_newton/damping.py +105 -0
  83. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  84. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  85. torchzero/modules/quasi_newton/lsr1.py +167 -132
  86. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  87. torchzero/modules/restarts/__init__.py +7 -0
  88. torchzero/modules/restarts/restars.py +253 -0
  89. torchzero/modules/second_order/__init__.py +2 -1
  90. torchzero/modules/second_order/multipoint.py +238 -0
  91. torchzero/modules/second_order/newton.py +133 -88
  92. torchzero/modules/second_order/newton_cg.py +207 -170
  93. torchzero/modules/smoothing/__init__.py +1 -1
  94. torchzero/modules/smoothing/sampling.py +300 -0
  95. torchzero/modules/step_size/__init__.py +1 -1
  96. torchzero/modules/step_size/adaptive.py +312 -47
  97. torchzero/modules/termination/__init__.py +14 -0
  98. torchzero/modules/termination/termination.py +207 -0
  99. torchzero/modules/trust_region/__init__.py +5 -0
  100. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  101. torchzero/modules/trust_region/dogleg.py +92 -0
  102. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  103. torchzero/modules/trust_region/trust_cg.py +99 -0
  104. torchzero/modules/trust_region/trust_region.py +350 -0
  105. torchzero/modules/variance_reduction/__init__.py +1 -0
  106. torchzero/modules/variance_reduction/svrg.py +208 -0
  107. torchzero/modules/weight_decay/weight_decay.py +65 -64
  108. torchzero/modules/zeroth_order/__init__.py +1 -0
  109. torchzero/modules/zeroth_order/cd.py +122 -0
  110. torchzero/optim/root.py +65 -0
  111. torchzero/optim/utility/split.py +8 -8
  112. torchzero/optim/wrappers/directsearch.py +0 -1
  113. torchzero/optim/wrappers/fcmaes.py +3 -2
  114. torchzero/optim/wrappers/nlopt.py +0 -2
  115. torchzero/optim/wrappers/optuna.py +2 -2
  116. torchzero/optim/wrappers/scipy.py +81 -22
  117. torchzero/utils/__init__.py +40 -4
  118. torchzero/utils/compile.py +1 -1
  119. torchzero/utils/derivatives.py +123 -111
  120. torchzero/utils/linalg/__init__.py +9 -2
  121. torchzero/utils/linalg/linear_operator.py +329 -0
  122. torchzero/utils/linalg/matrix_funcs.py +2 -2
  123. torchzero/utils/linalg/orthogonalize.py +2 -1
  124. torchzero/utils/linalg/qr.py +2 -2
  125. torchzero/utils/linalg/solve.py +226 -154
  126. torchzero/utils/metrics.py +83 -0
  127. torchzero/utils/optimizer.py +2 -2
  128. torchzero/utils/python_tools.py +7 -0
  129. torchzero/utils/tensorlist.py +105 -34
  130. torchzero/utils/torch_tools.py +9 -4
  131. torchzero-0.3.14.dist-info/METADATA +14 -0
  132. torchzero-0.3.14.dist-info/RECORD +167 -0
  133. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/top_level.txt +0 -1
  134. docs/source/conf.py +0 -59
  135. docs/source/docstring template.py +0 -46
  136. torchzero/modules/experimental/absoap.py +0 -253
  137. torchzero/modules/experimental/adadam.py +0 -118
  138. torchzero/modules/experimental/adamY.py +0 -131
  139. torchzero/modules/experimental/adam_lambertw.py +0 -149
  140. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  141. torchzero/modules/experimental/adasoap.py +0 -177
  142. torchzero/modules/experimental/cosine.py +0 -214
  143. torchzero/modules/experimental/cubic_adam.py +0 -97
  144. torchzero/modules/experimental/eigendescent.py +0 -120
  145. torchzero/modules/experimental/etf.py +0 -195
  146. torchzero/modules/experimental/exp_adam.py +0 -113
  147. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  148. torchzero/modules/experimental/hnewton.py +0 -85
  149. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  150. torchzero/modules/experimental/parabolic_search.py +0 -220
  151. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  152. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  153. torchzero/modules/line_search/polynomial.py +0 -233
  154. torchzero/modules/momentum/matrix_momentum.py +0 -193
  155. torchzero/modules/optimizers/adagrad.py +0 -165
  156. torchzero/modules/quasi_newton/trust_region.py +0 -397
  157. torchzero/modules/smoothing/gaussian.py +0 -198
  158. torchzero-0.3.11.dist-info/METADATA +0 -404
  159. torchzero-0.3.11.dist-info/RECORD +0 -159
  160. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  161. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  162. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  163. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  164. {torchzero-0.3.11.dist-info → torchzero-0.3.14.dist-info}/WHEEL +0 -0
@@ -53,7 +53,7 @@ class Alternate(Module):
  var = module.step(var.clone(clone_update=False))

  # number of steps until next module
- steps = self.settings[var.params[0]]['steps']
+ steps = self.defaults['steps']
  if isinstance(steps, int): steps = [steps]*len(self.children)

  if 'steps_to_next' not in self.global_state:
@@ -6,9 +6,5 @@ from .cautious import (
  ScaleModulesByCosineSimilarity,
  UpdateGradientSignConsistency,
  )
- from .ema import EMA, Debias, Debias2, EMASquared, SqrtEMASquared, CenteredEMASquared, CenteredSqrtEMASquared
- from .experimental import CoordinateMomentum
- # from .matrix_momentum import MatrixMomentum

- from .momentum import NAG, HeavyBall
- from .matrix_momentum import MatrixMomentum, AdaptiveMatrixMomentum
+ from .momentum import NAG, HeavyBall, EMA
@@ -10,7 +10,7 @@ from ...utils import tolist


  class Averaging(TensorwiseTransform):
- """Average of past :code:`history_size` updates.
+ """Average of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
@@ -35,7 +35,7 @@ class Averaging(TensorwiseTransform):
  return average / len(history)

  class WeightedAveraging(TensorwiseTransform):
- """Weighted average of past :code:`len(weights)` updates.
+ """Weighted average of past ``len(weights)`` updates.

  Args:
  weights (Sequence[float]): a sequence of weights from oldest to newest.
@@ -69,7 +69,7 @@ class WeightedAveraging(TensorwiseTransform):


  class MedianAveraging(TensorwiseTransform):
- """Median of past :code:`history_size` updates.
+ """Median of past ``history_size`` updates.

  Args:
  history_size (int): Number of past updates to average
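
The averaging transforms above plug into the same modular chain as the other docstring examples. A minimal usage sketch, assuming `WeightedAveraging` is exposed under `tz.m` like the other modules (model, weights and learning rate are illustrative):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# average the last three Adam updates, weighted from oldest to newest
opt = tz.Modular(
    model.parameters(),
    tz.m.Adam(),
    tz.m.WeightedAveraging(weights=[0.1, 0.3, 0.6]),
    tz.m.LR(1e-2),
)
```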
@@ -48,24 +48,22 @@ class Cautious(Transform):
  eps (float, optional): epsilon for normalization. Defaults to 1e-6.
  mode (str, optional):
  what to do with updates with inconsistent signs.
+ - "zero" - set them to zero (as in paper)
+ - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+ - "backtrack" - negate them

- "zero" - set them to zero (as in paper)
+ ## Examples:

- "grad" - set them to the gradient
+ Cautious Adam

- "backtrack" - negate them (same as using update magnitude and gradient sign)
-
- Examples:
- Cautious Adam
-
- .. code-block:: python
-
- opt = tz.Modular(
- bench.parameters(),
- tz.m.Adam(),
- tz.m.Cautious(),
- tz.m.LR(1e-2)
- )
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.Adam(),
+ tz.m.Cautious(),
+ tz.m.LR(1e-2)
+ )
+ ```

  References:
  Cautious Optimizers: Improving Training with One Line of Code. Kaizhao Liang, Lizhang Chen, Bo Liu, Qiang Liu
@@ -120,12 +118,9 @@ class IntermoduleCautious(Module):
  eps (float, optional): epsilon for normalization. Defaults to 1e-6.
  mode (str, optional):
  what to do with updates with inconsistent signs.
-
- "zero" - set them to zero (as in paper)
-
- "grad" - set them to the gradient
-
- "backtrack" - negate them (same as using update magnitude and gradient sign)
+ - "zero" - set them to zero (as in paper)
+ - "grad" - set them to the gradient (same as using update magnitude and gradient sign)
+ - "backtrack" - negate them
  """
  def __init__(
  self,
@@ -153,7 +148,7 @@ class IntermoduleCautious(Module):
  compare_var = compare.step(var.clone(clone_update=True))
  var.update_attrs_from_clone_(compare_var)

- mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.settings[var.params[0]])
+ mode, normalize, eps = itemgetter('mode', 'normalize', 'eps')(self.defaults)
  var.update = cautious_(
  TensorList(main_var.get_update()),
  TensorList(compare_var.get_update()),
@@ -171,17 +166,17 @@ class ScaleByGradCosineSimilarity(Transform):
  Args:
  eps (float, optional): epsilon for division. Defaults to 1e-6.

- Examples:
- Scaled Adam
-
- .. code-block:: python
-
- opt = tz.Modular(
- bench.parameters(),
- tz.m.Adam(),
- tz.m.ScaleByGradCosineSimilarity(),
- tz.m.LR(1e-2)
- )
+ ## Examples:
+
+ Scaled Adam
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.Adam(),
+ tz.m.ScaleByGradCosineSimilarity(),
+ tz.m.LR(1e-2)
+ )
+ ```
  """
  def __init__(
  self,
@@ -209,19 +204,19 @@ class ScaleModulesByCosineSimilarity(Module):
  compare (Chainable): module or sequence of modules to compare to
  eps (float, optional): epsilon for division. Defaults to 1e-6.

- Example:
- Adam scaled by similarity to RMSprop
-
- .. code-block:: python
-
- opt = tz.Modular(
- bench.parameters(),
- tz.m.ScaleModulesByCosineSimilarity(
- main = tz.m.Adam(),
- compare = tz.m.RMSprop(0.999, debiased=True),
- ),
- tz.m.LR(1e-2)
- )
+ ## Examples:
+
+ Adam scaled by similarity to RMSprop
+ ```python
+ opt = tz.Modular(
+ bench.parameters(),
+ tz.m.ScaleModulesByCosineSimilarity(
+ main = tz.m.Adam(),
+ compare = tz.m.RMSprop(0.999, debiased=True),
+ ),
+ tz.m.LR(1e-2)
+ )
+ ```
  """
  def __init__(
  self,
@@ -248,7 +243,7 @@ class ScaleModulesByCosineSimilarity(Module):

  m = TensorList(main_var.get_update())
  c = TensorList(compare_var.get_update())
- eps = self.settings[var.params[0]]['eps']
+ eps = self.defaults['eps']

  cos_sim = m.dot(c) / (m.global_vector_norm() * c.global_vector_norm()).clip(min=eps)

@@ -1,10 +1,44 @@
+ from collections import deque
+ from operator import itemgetter
  from typing import Literal

  import torch

  from ...core import Target, Transform
  from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
- from .ema import EMA
+ from ..functional import debias, ema_
+
+
+ class EMA(Transform):
+ """Maintains an exponential moving average of update.
+
+ Args:
+ momentum (float, optional): momentum (beta). Defaults to 0.9.
+ dampening (float, optional): momentum dampening. Defaults to 0.
+ debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
+ lerp (bool, optional): whether to use linear interpolation. Defaults to True.
+ ema_init (str, optional): initial values for the EMA, "zeros" or "update".
+ target (Target, optional): target to apply EMA to. Defaults to 'update'.
+ """
+ def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
+ defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
+ super().__init__(defaults, uses_grad=False, target=target)
+
+ @torch.no_grad
+ def apply_tensors(self, tensors, params, grads, loss, states, settings):
+ step = self.global_state['step'] = self.global_state.get('step', 0) + 1
+
+ debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
+
+ exp_avg = unpack_states(states, tensors, 'exp_avg',
+ init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
+ momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
+
+ exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
+
+ if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
+ else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+


  class HeavyBall(EMA):
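
For reference, a minimal sketch of how the relocated `EMA` transform might be used, following the `tz.Modular` pattern from the docstring examples elsewhere in this diff (exposure under `tz.m` and the chosen values are assumptions):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# debiased exponential moving average of the update, followed by a fixed step size
opt = tz.Modular(
    model.parameters(),
    tz.m.EMA(momentum=0.9, debiased=True),
    tz.m.LR(1e-2),
)
```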
@@ -27,6 +27,14 @@ from .binary import (
  Sub,
  Threshold,
  )
+ from .higher_level import (
+ CenteredEMASquared,
+ CenteredSqrtEMASquared,
+ Debias,
+ Debias2,
+ EMASquared,
+ SqrtEMASquared,
+ )
  from .multi import (
  ClipModules,
  DivModules,
@@ -64,7 +72,7 @@ from .utility import (
  Grad,
  GradToNone,
  Identity,
- NoOp,
+ Noop,
  Ones,
  Params,
  Randn,
@@ -57,8 +57,8 @@ class Add(BinaryOperationBase):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.settings[var.params[0]]['alpha'])
- else: torch._foreach_add_(update, other, alpha=self.settings[var.params[0]]['alpha'])
+ if isinstance(other, (int,float)): torch._foreach_add_(update, other * self.defaults['alpha'])
+ else: torch._foreach_add_(update, other, alpha=self.defaults['alpha'])
  return update

  class Sub(BinaryOperationBase):
@@ -72,8 +72,8 @@ class Sub(BinaryOperationBase):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: float | list[torch.Tensor]):
- if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.settings[var.params[0]]['alpha'])
- else: torch._foreach_sub_(update, other, alpha=self.settings[var.params[0]]['alpha'])
+ if isinstance(other, (int,float)): torch._foreach_sub_(update, other * self.defaults['alpha'])
+ else: torch._foreach_sub_(update, other, alpha=self.defaults['alpha'])
  return update

  class RSub(BinaryOperationBase):
@@ -219,7 +219,7 @@ class Graft(BinaryOperationBase):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], magnitude: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
  return TensorList(update).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps)

  class RGraft(BinaryOperationBase):
@@ -231,7 +231,7 @@ class RGraft(BinaryOperationBase):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], direction: list[torch.Tensor]):
- tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.settings[var.params[0]])
+ tensorwise, ord, eps = itemgetter('tensorwise','ord','eps')(self.defaults)
  return TensorList(direction).graft_(update, tensorwise=tensorwise, ord=ord, eps=eps)

  GraftToUpdate = RGraft
@@ -265,7 +265,8 @@ class GramSchimdt(BinaryOperationBase):
  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], other: list[torch.Tensor]):
  update = TensorList(update); other = TensorList(other)
- return update - (other*update) / ((other*other) + 1e-8)
+ min = torch.finfo(update[0].dtype).tiny * 2
+ return update - (other*update) / (other*other).clip(min=min)


  class Threshold(BinaryOperationBase):
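
The new clamp in `GramSchimdt.transform` derives a dtype-aware floor from `torch.finfo` instead of the previous hard-coded `1e-8`. A standalone sketch of the same safe-division pattern in plain PyTorch:

```python
import torch

u = torch.randn(4)
v = torch.zeros(4)  # worst case: the denominator would otherwise be exactly zero

floor = torch.finfo(u.dtype).tiny * 2          # smallest positive normal value for the dtype, doubled
proj = u - (v * u) / (v * v).clip(min=floor)   # division never hits zero, even in low precision
```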
@@ -276,7 +277,7 @@ class Threshold(BinaryOperationBase):

  @torch.no_grad
  def transform(self, var, update: list[torch.Tensor], threshold: list[torch.Tensor] | float, value: list[torch.Tensor] | float):
- update_above = self.settings[var.params[0]]['update_above']
+ update_above = self.defaults['update_above']
  update = TensorList(update)
  if update_above:
  if isinstance(value, list): return update.where_(update>threshold, value)
@@ -5,39 +5,16 @@ from typing import Literal
  import torch

  from ...core import Target, Transform
- from ...utils import TensorList, NumberList, unpack_dicts, unpack_states
- from ..functional import debias, ema_, ema_sq_, sqrt_ema_sq_, centered_ema_sq_, sqrt_centered_ema_sq_, debias_second_momentum
-
-
- class EMA(Transform):
- """Maintains an exponential moving average of update.
-
- Args:
- momentum (float, optional): momentum (beta). Defaults to 0.9.
- dampening (float, optional): momentum dampening. Defaults to 0.
- debiased (bool, optional): whether to debias the EMA like in Adam. Defaults to False.
- lerp (bool, optional): whether to use linear interpolation. Defaults to True.
- ema_init (str, optional): initial values for the EMA, "zeros" or "update".
- target (Target, optional): target to apply EMA to. Defaults to 'update'.
- """
- def __init__(self, momentum:float=0.9, dampening:float=0, debiased: bool = False, lerp=True, ema_init: Literal['zeros', 'update'] = 'zeros', target: Target = 'update'):
- defaults = dict(momentum=momentum,dampening=dampening,debiased=debiased,lerp=lerp,ema_init=ema_init)
- super().__init__(defaults, uses_grad=False, target=target)
-
- @torch.no_grad
- def apply_tensors(self, tensors, params, grads, loss, states, settings):
- step = self.global_state['step'] = self.global_state.get('step', 0) + 1
-
- debiased, lerp, ema_init = itemgetter('debiased','lerp','ema_init')(settings[0])
-
- exp_avg = unpack_states(states, tensors, 'exp_avg',
- init=torch.zeros_like if ema_init=='zeros' else tensors, cls=TensorList)
- momentum, dampening = unpack_dicts(settings, 'momentum','dampening', cls=NumberList)
-
- exp_avg = ema_(TensorList(tensors), exp_avg_=exp_avg,beta=momentum,dampening=dampening,lerp=lerp)
-
- if debiased: return debias(exp_avg, step=step, beta1=momentum, alpha=1, inplace=False)
- else: return exp_avg.clone() # this has exp_avg storage so needs to be cloned
+ from ...utils import NumberList, TensorList, unpack_dicts, unpack_states
+ from ..functional import (
+ centered_ema_sq_,
+ debias,
+ debias_second_momentum,
+ ema_,
+ ema_sq_,
+ sqrt_centered_ema_sq_,
+ sqrt_ema_sq_,
+ )


  class EMASquared(Transform):
@@ -8,7 +8,7 @@ from typing import Any, Literal
  import torch

  from ...core import Chainable, Module, Target, Var, maybe_chain
- from ...utils import TensorList, tensorlist
+ from ...utils import TensorList, tensorlist, Metrics


  class MultiOperationBase(Module, ABC):
@@ -59,7 +59,7 @@ class SubModules(MultiOperationBase):

  @torch.no_grad
  def transform(self, var: Var, input: float | list[torch.Tensor], other: float | list[torch.Tensor]) -> list[torch.Tensor]:
- alpha = self.settings[var.params[0]]['alpha']
+ alpha = self.defaults['alpha']

  if isinstance(input, (int,float)):
  assert isinstance(other, list)
@@ -112,7 +112,7 @@ class LerpModules(MultiOperationBase):

  @torch.no_grad
  def transform(self, var: Var, input: list[torch.Tensor], end: list[torch.Tensor]) -> list[torch.Tensor]:
- torch._foreach_lerp_(input, end, weight=self.settings[var.params[0]]['weight'])
+ torch._foreach_lerp_(input, end, weight=self.defaults['weight'])
  return input

  class ClipModules(MultiOperationBase):
@@ -154,45 +154,45 @@ class GraftModules(MultiOperationBase):
  Reference:
  Agarwal, N., Anil, R., Hazan, E., Koren, T., & Zhang, C. (2020). Disentangling adaptive gradient methods from learning rates. arXiv preprint arXiv:2002.11803. https://arxiv.org/pdf/2002.11803
  """
- def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:float=2, eps:float = 1e-6, strength:float=1):
+ def __init__(self, direction: Chainable, magnitude: Chainable, tensorwise:bool=True, ord:Metrics=2, eps:float = 1e-6, strength:float=1):
  defaults = dict(tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)
  super().__init__(defaults, direction=direction, magnitude=magnitude)

  @torch.no_grad
  def transform(self, var, magnitude: list[torch.Tensor], direction:list[torch.Tensor]):
- tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.settings[var.params[0]])
+ tensorwise, ord, eps, strength = itemgetter('tensorwise','ord','eps', 'strength')(self.defaults)
  return TensorList(direction).graft_(magnitude, tensorwise=tensorwise, ord=ord, eps=eps, strength=strength)

  class MultiplyByModuleNorm(MultiOperationBase):
  """Outputs :code:`input` multiplied by norm of the :code:`norm` output."""
- def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
  defaults = dict(tensorwise=tensorwise, ord=ord)
  super().__init__(defaults, input=input, norm=norm)

  @torch.no_grad
  def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
- tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
  if tensorwise:
- if ord == 'mean_abs': n = [t.mean() for t in torch._foreach_abs(norm)]
- else: n = torch._foreach_norm(norm, ord)
- else: n = TensorList(norm).global_vector_norm(ord)
+ n = TensorList(norm).metric(ord)
+ else:
+ n = TensorList(norm).global_metric(ord)

  torch._foreach_mul_(input, n)
  return input

  class DivideByModuleNorm(MultiOperationBase):
  """Outputs :code:`input` divided by norm of the :code:`norm` output."""
- def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:float|Literal['mean_abs']=2):
+ def __init__(self, input: Chainable, norm: Chainable, tensorwise:bool=True, ord:Metrics=2):
  defaults = dict(tensorwise=tensorwise, ord=ord)
  super().__init__(defaults, input=input, norm=norm)

  @torch.no_grad
  def transform(self, var, input: list[torch.Tensor], norm:list[torch.Tensor]):
- tensorwise, ord = itemgetter('tensorwise','ord')(self.settings[var.params[0]])
+ tensorwise, ord = itemgetter('tensorwise','ord')(self.defaults)
  if tensorwise:
- if ord == 'mean_abs': n = [t.mean().clip(min=1e-8) for t in torch._foreach_abs(norm)]
- else: n = torch._foreach_clamp_min(torch._foreach_norm(norm, ord), 1e-8)
- else: n = TensorList(norm).global_vector_norm(ord).clip(min=1e-8)
+ n = TensorList(norm).metric(ord)
+ else:
+ n = TensorList(norm).global_metric(ord)

  torch._foreach_div_(input, n)
  return input
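
As a usage illustration, `DivideByModuleNorm` nests other modules in the same way as `ScaleModulesByCosineSimilarity` above. A hedged sketch, assuming `DivideByModuleNorm` and `Grad` are exposed under `tz.m` (values are illustrative):

```python
import torch
import torchzero as tz

model = torch.nn.Linear(10, 1)

# Adam direction divided by the norm of the raw gradient
opt = tz.Modular(
    model.parameters(),
    tz.m.DivideByModuleNorm(input=tz.m.Adam(), norm=tz.m.Grad()),
    tz.m.LR(1e-2),
)
```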
@@ -81,7 +81,7 @@ class WeightedSum(ReduceOperationBase):
  @torch.no_grad
  def transform(self, var: Var, *inputs: float | list[torch.Tensor]) -> list[torch.Tensor]:
  sorted_inputs = sorted(inputs, key=lambda x: isinstance(x, float))
- weights = self.settings[var.params[0]]['weights']
+ weights = self.defaults['weights']
  sum = cast(list, sorted_inputs[0])
  torch._foreach_mul_(sum, weights[0])
  if len(sorted_inputs) > 1:
@@ -4,7 +4,7 @@ import torch

  from ...core import Module, Target, Transform
  from ...utils.tensorlist import Distributions, TensorList
-
+ from ...utils.linalg.linear_operator import ScaledIdentity

  class Clone(Module):
  """Clones input. May be useful to store some intermediate result and make sure it doesn't get affected by in-place operations"""
@@ -64,15 +64,15 @@ class Fill(Module):

  class RandomSample(Module):
  """Outputs tensors filled with random numbers from distribution depending on value of :code:`distribution`."""
- def __init__(self, eps: float = 1, distribution: Distributions = 'normal'):
- defaults = dict(eps=eps, distribution=distribution)
+ def __init__(self, distribution: Distributions = 'normal', variance:float | None = None):
+ defaults = dict(distribution=distribution, variance=variance)
  super().__init__(defaults)

  @torch.no_grad
  def step(self, var):
- var.update = TensorList(var.params).sample_like(
- eps=[self.settings[p]['eps'] for p in var.params], distribution=self.settings[var.params[0]]['distribution']
- )
+ distribution = self.defaults['distribution']
+ variance = self.get_settings(var.params, 'variance')
+ var.update = TensorList(var.params).sample_like(distribution=distribution, variance=variance)
  return var

  class Randn(Module):
@@ -112,9 +112,13 @@ class UpdateToNone(Module):
  return var

  class Identity(Module):
- """A placeholder identity operator that is argument-insensitive."""
+ """Identity operator that is argument-insensitive. This also can be used as identity hessian for trust region methods."""
  def __init__(self, *args, **kwargs): super().__init__()
  def step(self, var): return var
+ def get_H(self, var):
+ n = sum(p.numel() for p in var.params)
+ p = var.params[0]
+ return ScaledIdentity(shape=(n,n), device=p.device, dtype=p.dtype)

- NoOp = Identity
+ Noop = Identity
  """A placeholder identity operator that is argument-insensitive."""
@@ -1,7 +1,7 @@
  import math
  import warnings
  from abc import ABC, abstractmethod
- from collections import defaultdict, ChainMap
+ from collections import ChainMap, defaultdict
  from collections.abc import Iterable, Mapping, Sequence
  from functools import partial
  from typing import Any, Literal
@@ -9,7 +9,7 @@ from typing import Any, Literal
  import torch

  from ...core import Chainable, Module, Var
- from ...utils import vec_to_tensors, set_storage_
+ from ...utils import set_storage_, vec_to_tensors


  def _make_projected_closure(closure, project_fn, unproject_fn,
@@ -166,7 +166,7 @@ class ProjectionBase(Module, ABC):
  current=current,
  ))

- projected_var = var.clone(clone_update=False)
+ projected_var = var.clone(clone_update=False, parent=var)

  closure = var.closure

@@ -278,7 +278,7 @@ class ProjectionBase(Module, ABC):
  unprojected_var = projected_var.clone(clone_update=False)
  unprojected_var.closure = var.closure
  unprojected_var.params = var.params
- unprojected_var.grad = var.grad
+ unprojected_var.grad = var.grad # this may also be set by projected_var since it has var as parent

  if self._project_update:
  assert projected_var.update is not None
@@ -1,14 +1,3 @@
- from .cg import (
- ConjugateDescent,
- DaiYuan,
- FletcherReeves,
- HagerZhang,
- HestenesStiefel,
- HybridHS_DY,
- LiuStorey,
- PolakRibiere,
- ProjectedGradientMethod,
- )
  from .diagonal_quasi_newton import (
  DNRTR,
  DiagonalBFGS,
@@ -19,9 +8,6 @@ from .diagonal_quasi_newton import (
  )
  from .lbfgs import LBFGS
  from .lsr1 import LSR1
- # from .olbfgs import OnlineLBFGS
-
- # from .experimental import ModularLBFGS
  from .quasi_newton import (
  BFGS,
  DFP,
@@ -40,7 +26,6 @@ from .quasi_newton import (
  NewSSM,
  Pearson,
  ProjectedNewtonRaphson,
- ThomasOptimalMethod,
  ShorR,
+ ThomasOptimalMethod,
  )
- from .trust_region import CubicRegularization, TrustCG, TrustRegionBase