torchzero 0.3.11__py3-none-any.whl → 0.3.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. tests/test_opts.py +95 -69
  2. tests/test_tensorlist.py +8 -7
  3. torchzero/__init__.py +1 -1
  4. torchzero/core/__init__.py +2 -2
  5. torchzero/core/module.py +225 -72
  6. torchzero/core/reformulation.py +65 -0
  7. torchzero/core/transform.py +44 -24
  8. torchzero/modules/__init__.py +13 -5
  9. torchzero/modules/{optimizers → adaptive}/__init__.py +5 -2
  10. torchzero/modules/adaptive/adagrad.py +356 -0
  11. torchzero/modules/{optimizers → adaptive}/adahessian.py +53 -52
  12. torchzero/modules/{optimizers → adaptive}/adam.py +0 -3
  13. torchzero/modules/{optimizers → adaptive}/adan.py +26 -40
  14. torchzero/modules/{optimizers → adaptive}/adaptive_heavyball.py +3 -6
  15. torchzero/modules/adaptive/aegd.py +54 -0
  16. torchzero/modules/{optimizers → adaptive}/esgd.py +1 -1
  17. torchzero/modules/{optimizers/ladagrad.py → adaptive/lmadagrad.py} +42 -39
  18. torchzero/modules/{optimizers → adaptive}/mars.py +24 -36
  19. torchzero/modules/adaptive/matrix_momentum.py +146 -0
  20. torchzero/modules/{optimizers → adaptive}/msam.py +14 -12
  21. torchzero/modules/{optimizers → adaptive}/muon.py +19 -20
  22. torchzero/modules/adaptive/natural_gradient.py +175 -0
  23. torchzero/modules/{optimizers → adaptive}/rprop.py +0 -2
  24. torchzero/modules/{optimizers → adaptive}/sam.py +1 -1
  25. torchzero/modules/{optimizers → adaptive}/shampoo.py +8 -4
  26. torchzero/modules/{optimizers → adaptive}/soap.py +27 -50
  27. torchzero/modules/{optimizers → adaptive}/sophia_h.py +2 -3
  28. torchzero/modules/clipping/clipping.py +85 -92
  29. torchzero/modules/clipping/ema_clipping.py +5 -5
  30. torchzero/modules/conjugate_gradient/__init__.py +11 -0
  31. torchzero/modules/{quasi_newton → conjugate_gradient}/cg.py +355 -369
  32. torchzero/modules/experimental/__init__.py +9 -32
  33. torchzero/modules/experimental/dct.py +2 -2
  34. torchzero/modules/experimental/fft.py +2 -2
  35. torchzero/modules/experimental/gradmin.py +4 -3
  36. torchzero/modules/experimental/l_infinity.py +111 -0
  37. torchzero/modules/{momentum/experimental.py → experimental/momentum.py} +3 -40
  38. torchzero/modules/experimental/newton_solver.py +79 -17
  39. torchzero/modules/experimental/newtonnewton.py +27 -14
  40. torchzero/modules/experimental/scipy_newton_cg.py +105 -0
  41. torchzero/modules/experimental/structural_projections.py +1 -1
  42. torchzero/modules/functional.py +50 -14
  43. torchzero/modules/grad_approximation/fdm.py +19 -20
  44. torchzero/modules/grad_approximation/forward_gradient.py +4 -2
  45. torchzero/modules/grad_approximation/grad_approximator.py +43 -47
  46. torchzero/modules/grad_approximation/rfdm.py +144 -122
  47. torchzero/modules/higher_order/__init__.py +1 -1
  48. torchzero/modules/higher_order/higher_order_newton.py +31 -23
  49. torchzero/modules/least_squares/__init__.py +1 -0
  50. torchzero/modules/least_squares/gn.py +161 -0
  51. torchzero/modules/line_search/__init__.py +2 -2
  52. torchzero/modules/line_search/_polyinterp.py +289 -0
  53. torchzero/modules/line_search/adaptive.py +69 -44
  54. torchzero/modules/line_search/backtracking.py +83 -70
  55. torchzero/modules/line_search/line_search.py +159 -68
  56. torchzero/modules/line_search/scipy.py +1 -1
  57. torchzero/modules/line_search/strong_wolfe.py +319 -218
  58. torchzero/modules/misc/__init__.py +8 -0
  59. torchzero/modules/misc/debug.py +4 -4
  60. torchzero/modules/misc/escape.py +9 -7
  61. torchzero/modules/misc/gradient_accumulation.py +88 -22
  62. torchzero/modules/misc/homotopy.py +59 -0
  63. torchzero/modules/misc/misc.py +82 -15
  64. torchzero/modules/misc/multistep.py +47 -11
  65. torchzero/modules/misc/regularization.py +5 -9
  66. torchzero/modules/misc/split.py +55 -35
  67. torchzero/modules/misc/switch.py +1 -1
  68. torchzero/modules/momentum/__init__.py +1 -5
  69. torchzero/modules/momentum/averaging.py +3 -3
  70. torchzero/modules/momentum/cautious.py +42 -47
  71. torchzero/modules/momentum/momentum.py +35 -1
  72. torchzero/modules/ops/__init__.py +9 -1
  73. torchzero/modules/ops/binary.py +9 -8
  74. torchzero/modules/{momentum/ema.py → ops/higher_level.py} +10 -33
  75. torchzero/modules/ops/multi.py +15 -15
  76. torchzero/modules/ops/reduce.py +1 -1
  77. torchzero/modules/ops/utility.py +12 -8
  78. torchzero/modules/projections/projection.py +4 -4
  79. torchzero/modules/quasi_newton/__init__.py +1 -16
  80. torchzero/modules/quasi_newton/damping.py +105 -0
  81. torchzero/modules/quasi_newton/diagonal_quasi_newton.py +167 -163
  82. torchzero/modules/quasi_newton/lbfgs.py +256 -200
  83. torchzero/modules/quasi_newton/lsr1.py +167 -132
  84. torchzero/modules/quasi_newton/quasi_newton.py +346 -446
  85. torchzero/modules/restarts/__init__.py +7 -0
  86. torchzero/modules/restarts/restars.py +252 -0
  87. torchzero/modules/second_order/__init__.py +2 -1
  88. torchzero/modules/second_order/multipoint.py +238 -0
  89. torchzero/modules/second_order/newton.py +133 -88
  90. torchzero/modules/second_order/newton_cg.py +141 -80
  91. torchzero/modules/smoothing/__init__.py +1 -1
  92. torchzero/modules/smoothing/sampling.py +300 -0
  93. torchzero/modules/step_size/__init__.py +1 -1
  94. torchzero/modules/step_size/adaptive.py +312 -47
  95. torchzero/modules/termination/__init__.py +14 -0
  96. torchzero/modules/termination/termination.py +207 -0
  97. torchzero/modules/trust_region/__init__.py +5 -0
  98. torchzero/modules/trust_region/cubic_regularization.py +170 -0
  99. torchzero/modules/trust_region/dogleg.py +92 -0
  100. torchzero/modules/trust_region/levenberg_marquardt.py +128 -0
  101. torchzero/modules/trust_region/trust_cg.py +97 -0
  102. torchzero/modules/trust_region/trust_region.py +350 -0
  103. torchzero/modules/variance_reduction/__init__.py +1 -0
  104. torchzero/modules/variance_reduction/svrg.py +208 -0
  105. torchzero/modules/weight_decay/weight_decay.py +65 -64
  106. torchzero/modules/zeroth_order/__init__.py +1 -0
  107. torchzero/modules/zeroth_order/cd.py +359 -0
  108. torchzero/optim/root.py +65 -0
  109. torchzero/optim/utility/split.py +8 -8
  110. torchzero/optim/wrappers/directsearch.py +0 -1
  111. torchzero/optim/wrappers/fcmaes.py +3 -2
  112. torchzero/optim/wrappers/nlopt.py +0 -2
  113. torchzero/optim/wrappers/optuna.py +2 -2
  114. torchzero/optim/wrappers/scipy.py +81 -22
  115. torchzero/utils/__init__.py +40 -4
  116. torchzero/utils/compile.py +1 -1
  117. torchzero/utils/derivatives.py +123 -111
  118. torchzero/utils/linalg/__init__.py +9 -2
  119. torchzero/utils/linalg/linear_operator.py +329 -0
  120. torchzero/utils/linalg/matrix_funcs.py +2 -2
  121. torchzero/utils/linalg/orthogonalize.py +2 -1
  122. torchzero/utils/linalg/qr.py +2 -2
  123. torchzero/utils/linalg/solve.py +226 -154
  124. torchzero/utils/metrics.py +83 -0
  125. torchzero/utils/python_tools.py +6 -0
  126. torchzero/utils/tensorlist.py +105 -34
  127. torchzero/utils/torch_tools.py +9 -4
  128. torchzero-0.3.13.dist-info/METADATA +14 -0
  129. torchzero-0.3.13.dist-info/RECORD +166 -0
  130. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/top_level.txt +0 -1
  131. docs/source/conf.py +0 -59
  132. docs/source/docstring template.py +0 -46
  133. torchzero/modules/experimental/absoap.py +0 -253
  134. torchzero/modules/experimental/adadam.py +0 -118
  135. torchzero/modules/experimental/adamY.py +0 -131
  136. torchzero/modules/experimental/adam_lambertw.py +0 -149
  137. torchzero/modules/experimental/adaptive_step_size.py +0 -90
  138. torchzero/modules/experimental/adasoap.py +0 -177
  139. torchzero/modules/experimental/cosine.py +0 -214
  140. torchzero/modules/experimental/cubic_adam.py +0 -97
  141. torchzero/modules/experimental/eigendescent.py +0 -120
  142. torchzero/modules/experimental/etf.py +0 -195
  143. torchzero/modules/experimental/exp_adam.py +0 -113
  144. torchzero/modules/experimental/expanded_lbfgs.py +0 -141
  145. torchzero/modules/experimental/hnewton.py +0 -85
  146. torchzero/modules/experimental/modular_lbfgs.py +0 -265
  147. torchzero/modules/experimental/parabolic_search.py +0 -220
  148. torchzero/modules/experimental/subspace_preconditioners.py +0 -145
  149. torchzero/modules/experimental/tensor_adagrad.py +0 -42
  150. torchzero/modules/line_search/polynomial.py +0 -233
  151. torchzero/modules/momentum/matrix_momentum.py +0 -193
  152. torchzero/modules/optimizers/adagrad.py +0 -165
  153. torchzero/modules/quasi_newton/trust_region.py +0 -397
  154. torchzero/modules/smoothing/gaussian.py +0 -198
  155. torchzero-0.3.11.dist-info/METADATA +0 -404
  156. torchzero-0.3.11.dist-info/RECORD +0 -159
  157. torchzero-0.3.11.dist-info/licenses/LICENSE +0 -21
  158. /torchzero/modules/{optimizers → adaptive}/lion.py +0 -0
  159. /torchzero/modules/{optimizers → adaptive}/orthograd.py +0 -0
  160. /torchzero/modules/{optimizers → adaptive}/rmsprop.py +0 -0
  161. {torchzero-0.3.11.dist-info → torchzero-0.3.13.dist-info}/WHEEL +0 -0
@@ -0,0 +1,146 @@
1
+ from typing import Literal
2
+ from collections.abc import Callable
3
+ import torch
4
+
5
+ from ...core import Module, apply_transform, Chainable
6
+ from ...utils import NumberList, TensorList, as_tensorlist
7
+ from ...utils.derivatives import hvp, hvp_fd_central, hvp_fd_forward
8
+ from ..functional import initial_step_size
9
+
10
+
11
+ class MatrixMomentum(Module):
12
+ """Second order momentum method.
13
+
14
+ Matrix momentum is useful for convex objectives; for some reason it also has really good generalization on elastic net logistic regression.
15
+
16
+ Notes:
17
+ - ``mu`` needs to be tuned very carefully. It should be smaller than 1 / (the largest hessian eigenvalue), otherwise this will be very unstable. I have devised an adaptive version of this, ``tz.m.AdaptiveMatrixMomentum``, which works well without having to tune ``mu``; however, the adaptive version doesn't work on stochastic objectives.
18
+
19
+ - In most cases ``MatrixMomentum`` should be the first module in the chain because it relies on autograd.
20
+
21
+ - This module requires a closure to be passed to the optimizer step, as it needs to re-evaluate the loss and gradients to calculate HVPs. The closure must accept a ``backward`` argument.
22
+
23
+ Args:
24
+ mu (float, optional): this has a similar role to (1 - beta) in normal momentum. Defaults to 0.1.
25
+ hvp_method (str, optional):
26
+ Determines how Hessian-vector products are evaluated.
27
+
28
+ - ``"autograd"``: Use PyTorch's autograd to calculate exact HVPs.
29
+ This requires creating a graph for the gradient.
30
+ - ``"forward"``: Use a forward finite difference formula to
31
+ approximate the HVP. This requires one extra gradient evaluation.
32
+ - ``"central"``: Use a central finite difference formula for a
33
+ more accurate HVP approximation. This requires two extra
34
+ gradient evaluations.
35
+ Defaults to "autograd".
36
+ h (float, optional): finite difference step size if hvp_method is set to finite difference. Defaults to 1e-3.
37
+ hvp_tfm (Chainable | None, optional): optional module applied to hessian-vector products. Defaults to None.
38
+
39
+ Reference:
40
+ Orr, Genevieve, and Todd Leen. "Using curvature information for fast stochastic search." Advances in neural information processing systems 9 (1996).
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ lr:float,
46
+ mu=0.1,
47
+ hvp_method: Literal["autograd", "forward", "central"] = "autograd",
48
+ h: float = 1e-3,
49
+ adaptive:bool = False,
50
+ adapt_freq: int | None = None,
51
+ hvp_tfm: Chainable | None = None,
52
+ ):
53
+ defaults = dict(lr=lr, mu=mu, hvp_method=hvp_method, h=h, adaptive=adaptive, adapt_freq=adapt_freq)
54
+ super().__init__(defaults)
55
+
56
+ if hvp_tfm is not None:
57
+ self.set_child('hvp_tfm', hvp_tfm)
58
+
59
+ def reset_for_online(self):
60
+ super().reset_for_online()
61
+ self.clear_state_keys('p_prev')
62
+
63
+ @torch.no_grad
64
+ def update(self, var):
65
+ assert var.closure is not None
66
+ p = TensorList(var.params)
67
+ p_prev = self.get_state(p, 'p_prev', init=var.params)
68
+
69
+ hvp_method = self.defaults['hvp_method']
70
+ h = self.defaults['h']
71
+ step = self.global_state.get("step", 0)
72
+ self.global_state["step"] = step + 1
73
+
74
+ if step > 0:
75
+ s = p - p_prev
76
+
77
+ Hs, _ = self.Hvp(s, at_x0=True, var=var, rgrad=None, hvp_method=hvp_method, h=h, normalize=True, retain_grad=False)
78
+ Hs = [t.detach() for t in Hs]
79
+
80
+ if 'hvp_tfm' in self.children:
81
+ Hs = TensorList(apply_transform(self.children['hvp_tfm'], Hs, params=p, grads=var.grad, var=var))
82
+
83
+ self.store(p, ("Hs", "s"), (Hs, s))
84
+
85
+ # -------------------------------- adaptive mu ------------------------------- #
86
+ if self.defaults["adaptive"]:
87
+ g = TensorList(var.get_grad())
88
+
89
+ if self.defaults["adapt_freq"] is None:
90
+ # ---------------------------- deterministic case ---------------------------- #
91
+ g_prev = self.get_state(var.params, "g_prev", cls=TensorList)
92
+ y = g - g_prev
93
+ g_prev.copy_(g)
94
+ denom = y.global_vector_norm()
95
+ denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
96
+ self.global_state["mu_mul"] = s.global_vector_norm() / denom
97
+
98
+ else:
99
+ # -------------------------------- stochastic -------------------------------- #
100
+ adapt_freq = self.defaults["adapt_freq"]
101
+
102
+ # we start on the 1st step and want to adapt right away, so use (step - 1)
103
+ if (step - 1) % adapt_freq == 0:
104
+ assert var.closure is not None
105
+ params = TensorList(var.params)
106
+ p_cur = params.clone()
107
+
108
+ # move to previous params and evaluate p_prev with current mini-batch
109
+ params.copy_(self.get_state(var.params, 'p_prev'))
110
+ with torch.enable_grad():
111
+ var.closure()
112
+ g_prev = [p.grad if p.grad is not None else torch.zeros_like(p) for p in params]
113
+ y = g - g_prev
114
+
115
+ # move back to current params
116
+ params.copy_(p_cur)
117
+
118
+ denom = y.global_vector_norm()
119
+ denom = denom.clip(min=torch.finfo(denom.dtype).tiny * 2)
120
+ self.global_state["mu_mul"] = s.global_vector_norm() / denom
121
+
122
+ torch._foreach_copy_(p_prev, var.params)
123
+
124
+ @torch.no_grad
125
+ def apply(self, var):
126
+ update = TensorList(var.get_update())
127
+ lr,mu = self.get_settings(var.params, "lr", 'mu', cls=NumberList)
128
+
129
+ if "mu_mul" in self.global_state:
130
+ mu = mu * self.global_state["mu_mul"]
131
+
132
+ # --------------------------------- 1st step --------------------------------- #
133
+ # p_prev is not available so make a small step
134
+ step = self.global_state["step"]
135
+ if step == 1:
136
+ if self.defaults["adaptive"]: self.get_state(var.params, "g_prev", init=var.get_grad())
137
+ update.mul_(lr) # separate so that initial_step_size can clip correctly
138
+ update.mul_(initial_step_size(update, 1e-7))
139
+ return var
140
+
141
+ # -------------------------- matrix momentum update -------------------------- #
142
+ s, Hs = self.get_state(var.params, 's', 'Hs', cls=TensorList)
143
+
144
+ update.mul_(lr).sub_(s).add_(Hs*mu)
145
+ var.update = update
146
+ return var
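Since ``MatrixMomentum`` re-evaluates the loss and gradients for Hessian-vector products, the closure contract described in the docstring above matters in practice. A minimal usage sketch, assuming the module is exported as ``tz.m.MatrixMomentum`` like the other modules; the model, data, and hyperparameter values are illustrative, not package code:

```python
import torch
import torch.nn as nn
import torchzero as tz

model = nn.Linear(10, 1)
X, y = torch.randn(64, 10), torch.randn(64, 1)

# MatrixMomentum already takes an lr, so no separate tz.m.LR module is used here
opt = tz.Modular(model.parameters(), tz.m.MatrixMomentum(lr=1e-2, mu=0.01))

def closure(backward=True):
    # the ``backward`` argument lets the module re-evaluate the loss
    # without a backward pass when it only needs the value
    loss = (model(X) - y).pow(2).mean()
    if backward:
        opt.zero_grad()
        loss.backward()
    return loss

for _ in range(100):
    loss = opt.step(closure)
```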
@@ -42,13 +42,15 @@ def msam_(
42
42
  # can't really decouple it from lr
43
43
  # but at least it is now expressed as function of g
44
44
 
45
- denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
45
+ denom = velocity_.global_vector_norm() / rho
46
+ denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
46
47
  vn = velocity_ / denom
47
48
 
48
49
  mom_ = nag_ if nesterov else ema_
49
50
  velocity_ = mom_(tensors, velocity_, momentum, dampening=0, lerp=lerp)
50
51
 
51
- denom = (velocity_.global_vector_norm() / rho).clip(min=1e-8)
52
+ denom = velocity_.global_vector_norm() / rho
53
+ denom = denom.clip(min=torch.finfo(tensors[0].dtype).tiny * 2)
52
54
  v1n = velocity_ / denom
53
55
 
54
56
  if inner is not None:
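The ``msam_`` change above replaces the hard-coded ``1e-8`` floor with a dtype-aware one. A standalone illustration of what that threshold evaluates to (not package code):

```python
import torch

# the new clipping floor: twice the smallest positive normal value of the dtype
for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
    print(dtype, torch.finfo(dtype).tiny * 2)
```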
@@ -74,11 +76,11 @@ class MSAM(Transform):
74
76
  replacement for momentum strategies in other optimizers.
75
77
 
76
78
  To combine MSAM with other optimizers in the way done in the official implementation,
77
- e.g. to make Adam_MSAM, use :code:`tz.m.MSAMObjective` module.
79
+ e.g. to make Adam_MSAM, use the ``tz.m.MSAMObjective`` module.
78
80
 
79
- .. note::
81
+ Note
80
82
  MSAM has a learning rate hyperparameter that can't really be removed from the update rule.
81
- To avoid compounding learning rate mofications, remove the :code:`tz.m.LR` module if you had it.
83
+ To avoid compounding learning rate modifications, remove the ``tz.m.LR`` module if you had it.
82
84
 
83
85
  Args:
84
86
  lr (float): learning rate. Adding this module adds support for learning rate schedulers.
@@ -112,10 +114,10 @@ class MSAM(Transform):
112
114
  tz.m.Debias(0.9, 0.999),
113
115
  )
114
116
  """
115
- USES_LR = True
117
+ _USES_LR = True
116
118
  def __init__(self, lr: float, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False,):
117
119
  defaults = dict(momentum=momentum,rho=rho, nesterov=nesterov, lerp=lerp, weight_decay=weight_decay)
118
- if self.USES_LR: defaults['lr'] = lr
120
+ if self._USES_LR: defaults['lr'] = lr
119
121
  super().__init__(defaults, uses_grad=False)
120
122
 
121
123
  @torch.no_grad
@@ -125,7 +127,7 @@ class MSAM(Transform):
125
127
  lerp = s['lerp']
126
128
  nesterov = s['nesterov']
127
129
 
128
- if self.USES_LR:
130
+ if self._USES_LR:
129
131
  lr, momentum, rho, weight_decay = unpack_dicts(settings, 'lr','momentum','rho','weight_decay', cls=NumberList)
130
132
 
131
133
  else:
@@ -152,9 +154,9 @@ class MSAM(Transform):
152
154
  class MSAMObjective(MSAM):
153
155
  """Momentum-SAM from https://arxiv.org/pdf/2401.12033.
154
156
 
155
- .. note::
156
- Please make sure to place :code:`tz.m.LR` inside the :code:`modules` argument. For example,
157
- :code:`tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])`. Putting LR after MSAM will lead
157
+ Note:
158
+ Please make sure to place ``tz.m.LR`` inside the ``modules`` argument. For example,
159
+ ``tz.m.MSAMObjective([tz.m.Adam(), tz.m.LR(1e-3)])``. Putting LR after MSAM will lead
158
160
  to an incorrect update rule.
159
161
 
160
162
  Args:
@@ -179,7 +181,7 @@ class MSAMObjective(MSAM):
179
181
  )
180
182
  )
181
183
  """
182
- USES_LR = False
184
+ _USES_LR = False
183
185
  def __init__(self, modules: Chainable, momentum:float=0.9, rho:float=0.3, weight_decay:float=0, nesterov=False, lerp=False):
184
186
  super().__init__(lr=0, momentum=momentum, rho=rho, weight_decay=weight_decay, nesterov=nesterov, lerp=lerp)
185
187
  self.set_child('modules', modules)
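To make the placement rule in the note above concrete, a small sketch of an Adam_MSAM-style setup following the docstring's example (``model`` is assumed to be an existing ``nn.Module``):

```python
import torchzero as tz

opt = tz.Modular(
    model.parameters(),
    tz.m.MSAMObjective(
        [tz.m.Adam(), tz.m.LR(1e-3)],  # LR lives inside ``modules``, not after MSAM
        momentum=0.9,
        rho=0.3,
    ),
)
```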
@@ -167,26 +167,25 @@ class Orthogonalize(TensorwiseTransform):
167
167
  target (str, optional):
168
168
  what to set on var.
169
169
 
170
-
171
- Examples:
172
- standard Muon with Adam fallback
173
-
174
- .. code-block:: python
175
-
176
- opt = tz.Modular(
177
- model.head.parameters(),
178
- tz.m.Split(
179
- # apply muon only to 2D+ parameters
180
- filter = lambda t: t.ndim >= 2,
181
- true = [
182
- tz.m.HeavyBall(),
183
- tz.m.Orthogonalize(),
184
- tz.m.LR(1e-2),
185
- ],
186
- false = tz.m.Adam()
187
- ),
188
- tz.m.LR(1e-2)
189
- )
170
+ ## Examples:
171
+
172
+ standard Muon with Adam fallback
173
+ ```py
174
+ opt = tz.Modular(
175
+ model.head.parameters(),
176
+ tz.m.Split(
177
+ # apply muon only to 2D+ parameters
178
+ filter = lambda t: t.ndim >= 2,
179
+ true = [
180
+ tz.m.HeavyBall(),
181
+ tz.m.Orthogonalize(),
182
+ tz.m.LR(1e-2),
183
+ ],
184
+ false = tz.m.Adam()
185
+ ),
186
+ tz.m.LR(1e-2)
187
+ )
188
+ ```
190
189
 
191
190
  Reference:
192
191
  Keller Jordan, Yuchen Jin, Vlado Boza, You Jiacheng, Franz Cesista, Laker Newhouse, Jeremy Bernstein - Muon: An optimizer for hidden layers in neural networks (2024) https://github.com/KellerJordan/Muon
@@ -0,0 +1,175 @@
1
+ import torch
2
+ from ...core import Module, Chainable, apply_transform
3
+
4
+ from ...utils.derivatives import jacobian_wrt, flatten_jacobian
5
+ from ...utils import vec_to_tensors, TensorList
6
+ from ...utils.linalg import linear_operator
7
+ from .lmadagrad import lm_adagrad_apply, lm_adagrad_update
8
+
9
+ class NaturalGradient(Module):
10
+ """Natural gradient approximated via empirical fisher information matrix.
11
+
12
+ To use this, either pass a vector of per-sample losses to the step method, or make sure
13
+ the closure returns it. Gradients will be calculated via batched autograd within this module;
14
+ you don't need to implement the backward pass. When using a closure, please add the ``backward`` argument;
15
+ it will always be False but it is required. See below for an example.
16
+
17
+ Note:
18
+ The empirical Fisher information matrix may give a really bad approximation in some cases.
19
+ If that is the case, set ``sqrt`` to True to perform whitening instead, which is way more robust.
20
+
21
+ Args:
22
+ reg (float, optional): regularization parameter. Defaults to 1e-8.
23
+ sqrt (bool, optional):
24
+ if True, uses the square root of the empirical Fisher information matrix. Both the EFIM and its square
25
+ root can be calculated and stored efficiently without ndim^2 memory. Square root
26
+ whitens the gradient and often performs much better, especially when you try to use NGD
27
+ with a vector that isn't strictly per-sample gradients, but rather for example different losses.
28
+ gn_grad (bool, optional):
29
+ if True, uses Gauss-Newton G^T @ f as the gradient, which is effectively a sum weighted by value
30
+ and is equivalent to squaring the values. This way you can solve least-squares
31
+ objectives with a NGD-like algorithm. If False, uses sum of per-sample gradients.
32
+ This has an effect when ``sqrt=True``, and affects the ``grad`` attribute.
33
+ Defaults to False.
34
+ batched (bool, optional): whether to use vmapping. Defaults to True.
35
+
36
+ Examples:
37
+
38
+ training a neural network:
39
+ ```python
40
+ X = torch.randn(64, 20)
41
+ y = torch.randn(64, 10)
42
+
43
+ model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
44
+ opt = tz.Modular(
45
+ model.parameters(),
46
+ tz.m.NaturalGradient(),
47
+ tz.m.LR(3e-2)
48
+ )
49
+
50
+ for i in range(100):
51
+ y_hat = model(X) # (64, 10)
52
+ losses = (y_hat - y).pow(2).mean(0) # (10, )
53
+ opt.step(loss=losses)
54
+ if i % 10 == 0:
55
+ print(f'{losses.mean() = }')
56
+ ```
57
+
58
+ training a neural network - closure version
59
+ ```python
60
+ X = torch.randn(64, 20)
61
+ y = torch.randn(64, 10)
62
+
63
+ model = nn.Sequential(nn.Linear(20, 64), nn.ELU(), nn.Linear(64, 10))
64
+ opt = tz.Modular(
65
+ model.parameters(),
66
+ tz.m.NaturalGradient(),
67
+ tz.m.LR(3e-2)
68
+ )
69
+
70
+ def closure(backward=True):
71
+ y_hat = model(X) # (64, 10)
72
+ return (y_hat - y).pow(2).mean(0) # (10, )
73
+
74
+ for i in range(100):
75
+ losses = opt.step(closure)
76
+ if i % 10 == 0:
77
+ print(f'{losses.mean() = }')
78
+ ```
79
+
80
+ minimizing the rosenbrock function with a mix of natural gradient, whitening and gauss-newton:
81
+ ```python
82
+ def rosenbrock(X):
83
+ x1, x2 = X
84
+ return torch.stack([(1 - x1).abs(), (10 * (x2 - x1**2).abs())])
85
+
86
+ X = torch.tensor([-1.1, 2.5], requires_grad=True)
87
+ opt = tz.Modular([X], tz.m.NaturalGradient(sqrt=True, gn_grad=True), tz.m.LR(0.05))
88
+
89
+ for iter in range(200):
90
+ losses = rosenbrock(X)
91
+ opt.step(loss=losses)
92
+ if iter % 20 == 0:
93
+ print(f'{losses.mean() = }')
94
+ ```
95
+ """
96
+ def __init__(self, reg:float = 1e-8, sqrt:bool=False, gn_grad:bool=False, batched:bool=True, ):
97
+ super().__init__(defaults=dict(batched=batched, reg=reg, sqrt=sqrt, gn_grad=gn_grad))
98
+
99
+ @torch.no_grad
100
+ def update(self, var):
101
+ params = var.params
102
+ batched = self.defaults['batched']
103
+ gn_grad = self.defaults['gn_grad']
104
+
105
+ closure = var.closure
106
+ assert closure is not None
107
+
108
+ with torch.enable_grad():
109
+ f = var.get_loss(backward=False) # n_out
110
+ assert isinstance(f, torch.Tensor)
111
+ G_list = jacobian_wrt([f.ravel()], params, batched=batched)
112
+
113
+ var.loss = f.sum()
114
+ G = self.global_state["G"] = flatten_jacobian(G_list) # (n_samples, ndim)
115
+
116
+ if gn_grad:
117
+ g = self.global_state["g"] = G.H @ f.detach()
118
+
119
+ else:
120
+ g = self.global_state["g"] = G.sum(0)
121
+
122
+ var.grad = vec_to_tensors(g, params)
123
+
124
+ # set closure to calculate scalar value for line searches etc
125
+ if var.closure is not None:
126
+ def ngd_closure(backward=True):
127
+ if backward:
128
+ var.zero_grad()
129
+ with torch.enable_grad():
130
+ loss = closure(False)
131
+ if gn_grad: loss = loss.pow(2)
132
+ loss = loss.sum()
133
+ loss.backward()
134
+ return loss
135
+
136
+ loss = closure(False)
137
+ if gn_grad: loss = loss.pow(2)
138
+ return loss.sum()
139
+
140
+ var.closure = ngd_closure
141
+
142
+ @torch.no_grad
143
+ def apply(self, var):
144
+ params = var.params
145
+ reg = self.defaults['reg']
146
+ sqrt = self.defaults['sqrt']
147
+
148
+ G: torch.Tensor = self.global_state['G'] # (n_samples, n_dim)
149
+
150
+ if sqrt:
151
+ # this computes U, S <- SVD(M), then calculates the update as U S^-1 Uᵀg,
152
+ # but it computes it through an eigendecomposition
153
+ U, L = lm_adagrad_update(G.H, reg, 0)
154
+ if U is None or L is None: return var
155
+
156
+ v = lm_adagrad_apply(self.global_state["g"], U, L)
157
+ var.update = vec_to_tensors(v, params)
158
+ return var
159
+
160
+ GGT = G @ G.H # (n_samples, n_samples)
161
+
162
+ if reg != 0:
163
+ GGT.add_(torch.eye(GGT.size(0), device=GGT.device, dtype=GGT.dtype).mul_(reg))
164
+
165
+ z, _ = torch.linalg.solve_ex(GGT, torch.ones_like(GGT[0])) # pylint:disable=not-callable
166
+ v = G.H @ z
167
+
168
+ var.update = vec_to_tensors(v, params)
169
+ return var
170
+
171
+
172
+ def get_H(self, var):
173
+ if "G" not in self.global_state: return linear_operator.ScaledIdentity()
174
+ G = self.global_state['G']
175
+ return linear_operator.AtA(G)
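The unregularized branch of ``apply`` above solves an ``n_samples x n_samples`` system rather than an ``ndim x ndim`` one. A standalone numerical check of the identity this appears to rely on (with ``F = GᵀG`` and ``g = Gᵀ1``, the pseudo-inverse step ``F⁺g`` equals ``Gᵀ(GGᵀ)⁻¹1`` when ``G`` has full row rank); this is not package code:

```python
import torch

G = torch.randn(8, 50)                  # per-sample gradients, (n_samples, ndim)
ones = torch.ones(G.size(0))            # note g = G.sum(0) is the same as G.T @ ones

z = torch.linalg.solve(G @ G.T, ones)   # small n_samples x n_samples solve
v_kernel = G.T @ z                      # what NaturalGradient.apply computes

v_pinv = torch.linalg.pinv(G) @ ones    # G⁺ @ 1 = Gᵀ(GGᵀ)⁻¹ @ 1 for full row rank G
print(torch.allclose(v_kernel, v_pinv, atol=1e-4))
```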
@@ -258,8 +258,6 @@ class BacktrackOnSignChange(Transform):
258
258
  This is part of RProp update rule.
259
259
 
260
260
  Args:
261
- normalize (bool, optional): renormalize update after masking. Defaults to False.
262
- eps (_type_, optional): epsilon for normalization. Defaults to 1e-6.
263
261
  use_grad (bool, optional):
264
262
  if True, tracks sign change of the gradient,
265
263
  otherwise track sign change of the update. Defaults to True.
@@ -63,7 +63,7 @@ class SAM(Module):
63
63
  zero_grad = var.zero_grad
64
64
  if closure is None: raise RuntimeError("SAM requires a closure passed to the optimizer step")
65
65
  p, rho = self.get_settings(var.params, 'p', 'rho', cls=NumberList)
66
- s = self.settings[var.params[0]]
66
+ s = self.defaults
67
67
  eps = s['eps']
68
68
  asam = s['asam']
69
69
 
@@ -17,6 +17,7 @@ def update_shampoo_preconditioner_(
17
17
  update_freq: int,
18
18
  exp_override: int | None,
19
19
  beta: float | None,
20
+ reg: float
20
21
  ):
21
22
  for i, (accumulator, preconditioner) in enumerate(zip(accumulators_, preconditioners_)):
22
23
  if accumulator is None: continue
@@ -28,6 +29,8 @@ def update_shampoo_preconditioner_(
28
29
 
29
30
  if step % update_freq == 0:
30
31
  matrix_exp = -1/(grad.ndim*2) if exp_override is None else -1/exp_override
32
+ if reg != 0:
33
+ accumulator = accumulator + torch.eye(accumulator.size(0), device=accumulator.device, dtype=accumulator.dtype).mul_(reg)
31
34
  set_storage_(preconditioner, matrix_power_eigh(accumulator, matrix_exp))
32
35
 
33
36
 
@@ -99,7 +102,6 @@ class Shampoo(Transform):
99
102
  decay (float | None, optional): slowly decays preconditioners. Defaults to None.
100
103
  beta (float | None, optional):
101
104
  if None calculates sum as in standard shampoo, otherwise uses EMA of preconditioners. Defaults to None.
102
- matrix_eps (float, optional): epsilon for matrix operations. Defaults to 1e-10.
103
105
  update_freq (int, optional): preconditioner update frequency. Defaults to 10.
104
106
  exp_override (int | None, optional): matrix exponent override, if not set, uses 2*ndim. Defaults to 2.
105
107
  merge_small (bool, optional): whether to merge small dims on tensors. Defaults to True.
@@ -140,6 +142,7 @@ class Shampoo(Transform):
140
142
  self,
141
143
  decay: float | None = None,
142
144
  beta: float | None = None,
145
+ reg: float = 1e-12,
143
146
  update_freq: int = 10,
144
147
  exp_override: int | None = 2,
145
148
  merge_small: bool = True,
@@ -148,7 +151,7 @@ class Shampoo(Transform):
148
151
  adagrad_eps: float = 1e-8,
149
152
  inner: Chainable | None = None,
150
153
  ):
151
- defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps)
154
+ defaults = dict(decay=decay, beta=beta, update_freq=update_freq, exp_override=exp_override, merge_small=merge_small, max_dim=max_dim, precondition_1d=precondition_1d,adagrad_eps=adagrad_eps, reg=reg)
152
155
  super().__init__(defaults, uses_grad=False)
153
156
 
154
157
  if inner is not None:
@@ -159,8 +162,8 @@ class Shampoo(Transform):
159
162
 
160
163
  # update preconditioners
161
164
  for i,(t,state, setting) in enumerate(zip(tensors, states, settings)):
162
- beta, update_freq, exp_override, merge_small, max_dim, precondition_1d = itemgetter(
163
- 'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d')(setting)
165
+ beta, update_freq, exp_override, merge_small, max_dim, precondition_1d, reg = itemgetter(
166
+ 'beta', 'update_freq', 'exp_override', 'merge_small', 'max_dim', 'precondition_1d', "reg")(setting)
164
167
 
165
168
  if merge_small:
166
169
  t, state['flat_sizes'], state['sort_idxs'] = _merge_small_dims(t, max_dim)
@@ -195,6 +198,7 @@ class Shampoo(Transform):
195
198
  update_freq=update_freq,
196
199
  exp_override=exp_override,
197
200
  beta=beta,
201
+ reg=reg,
198
202
  )
199
203
 
200
204
  # inner step
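The new ``reg`` parameter adds ``reg * I`` to an accumulator right before its inverse matrix root is taken, which keeps near-singular accumulators invertible. A rough standalone sketch for one 2D factor; the eigendecomposition mirrors what a ``matrix_power_eigh``-style helper would do, and the names and sizes are illustrative:

```python
import torch

A = torch.randn(16, 32)
accumulator = A @ A.T                       # PSD factor, as Shampoo accumulates
reg = 1e-12
regularized = accumulator + reg * torch.eye(16)

# inverse 4th root: exp_override=None and ndim=2 give exponent -1/(2*ndim) = -1/4
L, Q = torch.linalg.eigh(regularized)
preconditioner = Q @ torch.diag(L.clamp(min=reg).pow(-0.25)) @ Q.T
```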
@@ -1,9 +1,10 @@
1
1
  from operator import itemgetter
2
+ import warnings
2
3
 
3
4
  import torch
4
5
 
5
6
  from ...core import Chainable, Transform, apply_transform
6
- from ...modules.optimizers.shampoo import _merge_small_dims, _unmerge_small_dims
7
+ from ...modules.adaptive.shampoo import _merge_small_dims, _unmerge_small_dims
7
8
 
8
9
  @torch.no_grad
9
10
  def update_soap_covariances_(
@@ -52,36 +53,23 @@ def get_orthogonal_matrix(mat: list[torch.Tensor | None]):
52
53
  """
53
54
  Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
54
55
  """
55
- matrix = []
56
- float_data = False
57
- original_type = original_device = None
58
- for m in mat:
59
- if m is None or len(m) == 0:
60
- matrix.append([])
61
- continue
62
- if m.dtype != torch.float:
63
- original_type = m.dtype
64
- original_device = m.device
65
- matrix.append(m.float())
66
- else:
67
- float_data = True
68
- matrix.append(m)
69
56
 
70
57
  final = []
71
- for m in matrix:
72
- if len(m) == 0:
58
+ for m in mat:
59
+
60
+ if m is None or len(m) == 0:
73
61
  final.append([])
74
62
  continue
63
+
75
64
  try:
76
65
  _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
77
- except Exception:
66
+ except torch.linalg.LinAlgError:
78
67
  _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device)) # pylint:disable=not-callable
79
68
  Q = Q.to(m.dtype)
80
- Q = torch.flip(Q, [1])
81
69
 
82
- if not float_data:
83
- Q = Q.to(original_device).type(original_type)
70
+ Q = torch.flip(Q, [1])
84
71
  final.append(Q)
72
+
85
73
  return final
86
74
 
87
75
  # function from https://github.com/nikhilvyas/SOAP/blob/main/soap.py#L240
@@ -91,40 +79,24 @@ def get_orthogonal_matrix_QR(exp_avg_sq: torch.Tensor, GG: list[torch.Tensor | N
91
79
  Computes the eigenbases of the preconditioner using one round of power iteration
92
80
  followed by torch.linalg.qr decomposition.
93
81
  """
94
- matrix = []
95
- orth_matrix = []
96
- float_data = False
97
- original_type = original_device = None
98
- for m,o in zip(GG, Q_list):
82
+ final = []
83
+
84
+ for ind, (m,o) in enumerate(zip(GG, Q_list)):
85
+
86
+ # skip 1d or large dims
99
87
  if m is None or len(m) == 0:
100
- matrix.append([])
101
- orth_matrix.append([])
88
+ final.append([])
102
89
  continue
103
90
  assert o is not None
104
- if m.data.dtype != torch.float:
105
- original_type = m.data.dtype
106
- original_device = m.data.device
107
- matrix.append(m.data.float())
108
- orth_matrix.append(o.data.float())
109
- else:
110
- float_data = True
111
- matrix.append(m.data.float())
112
- orth_matrix.append(o.data.float())
113
91
 
114
- final = []
115
- for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
116
- if len(m)==0:
117
- final.append([])
118
- continue
119
92
  est_eig = torch.diag(o.T @ m @ o)
120
93
  sort_idx = torch.argsort(est_eig, descending=True)
121
94
  exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
122
- o = o[:,sort_idx]
123
- power_iter = m @ o
124
- Q, _ = torch.linalg.qr(power_iter) # pylint:disable=not-callable
125
95
 
126
- if not float_data:
127
- Q = Q.to(original_device).type(original_type)
96
+ power_iter = m @ o[:, sort_idx]
97
+ Q, _ = torch.linalg.qr(power_iter.to(torch.float32)) # pylint:disable=not-callable
98
+ Q = Q.to(power_iter.dtype)
99
+
128
100
  final.append(Q)
129
101
 
130
102
  return final, exp_avg_sq
@@ -226,7 +198,10 @@ class SOAP(Transform):
226
198
 
227
199
  if state['GG'] is not None:
228
200
  update_soap_covariances_(t, GGs_=state['GG'], beta=shampoo_beta)
229
- state['Q'] = get_orthogonal_matrix(state['GG'])
201
+ try: state['Q'] = get_orthogonal_matrix(state['GG'])
202
+ except torch.linalg.LinAlgError as e:
203
+ warnings.warn(f"torch.linalg.eigh raised an error when initializing SOAP Q matrices on 1st step, diagonal preconditioning will be used for this parameter. The error was:\n{e}")
204
+ state["GG"] = None
230
205
 
231
206
  state['step'] = 0
232
207
  updates.append(tensors[i].clip(-0.1, 0.1))
@@ -283,6 +258,8 @@ class SOAP(Transform):
283
258
  if state['GG'] is not None:
284
259
  update_soap_covariances_(t, state['GG'], shampoo_beta)
285
260
  if state['step'] % setting['precond_freq'] == 0:
286
- state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
287
-
261
+ try:
262
+ state['Q'], state['exp_avg_sq_projected'] = get_orthogonal_matrix_QR(exp_avg_sq_projected, state['GG'], state['Q'])
263
+ except torch.linalg.LinAlgError:
264
+ pass
288
265
  return updates
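The SOAP changes above switch from a bare ``except Exception`` to catching ``torch.linalg.LinAlgError`` specifically, falling back when the factorization fails. A tiny standalone illustration of that exception type (not package code):

```python
import torch

M = torch.zeros(4, 4)                  # not positive-definite, so cholesky fails
try:
    torch.linalg.cholesky(M)
except torch.linalg.LinAlgError:
    pass                               # fall back, as SOAP now does for eigh/QR failures
```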
@@ -4,8 +4,6 @@ import torch
4
4
 
5
5
  from ...core import Module, Target, Transform, Chainable, apply_transform
6
6
  from ...utils import NumberList, TensorList, as_tensorlist
7
- from ...utils.derivatives import hvp, hvp_fd_forward, hvp_fd_central
8
-
9
7
  def sophia_H(
10
8
  tensors: TensorList,
11
9
  h: TensorList | None,
@@ -72,7 +70,7 @@ class SophiaH(Module):
72
70
  more accurate HVP approximation. This requires two extra
73
71
  gradient evaluations.
74
72
  Defaults to "autograd".
75
- h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
73
+ fd_h (float, optional): finite difference step size if :code:`hvp_method` is "forward" or "central". Defaults to 1e-3.
76
74
  n_samples (int, optional):
77
75
  number of hessian-vector products with random vectors to evaluate each time when updating
78
76
  the preconditioner. Larger values may lead to better hessian diagonal estimate. Defaults to 1.
@@ -159,6 +157,7 @@ class SophiaH(Module):
159
157
 
160
158
  Hvp, rgrad = self.Hvp(u, at_x0=True, var=var, rgrad=rgrad, hvp_method=hvp_method,
161
159
  h=fd_h, normalize=True, retain_grad=i < n_samples-1)
160
+ Hvp = tuple(Hvp)
162
161
 
163
162
  if h is None: h = Hvp
164
163
  else: torch._foreach_add_(h, Hvp)